jax.extend.linear_util.transformation_with_aux#
- jax.extend.linear_util.transformation_with_aux = functools.partial(<class 'functools.partial'>, <function transformation_with_aux>)[source]#
为 WrappedFun 添加一个具有辅助输出的变换。
- 参数:
fun (WrappedFun)
use_eq_store (bool)
- 返回类型:
tuple[WrappedFun, Callable[[], Any]]