jax.extend.linear_util.transformation_with_aux

jax.extend.linear_util.transformation_with_aux#

jax.extend.linear_util.transformation_with_aux = functools.partial(<class 'functools.partial'>, <function transformation_with_aux>)[source]#

为 WrappedFun 添加一个具有辅助输出的变换。

参数:
返回类型:

tuple[WrappedFun, Callable[[], Any]]