jax.extend.linear_util.transformation

jax.extend.linear_util.transformation#

jax.extend.linear_util.transformation = functools.partial(<class 'functools.partial'>, <function transformation>)[source]#

为 WrappedFun 添加另一个变换。

参数:
  • gen – 变换生成器函数

  • fun (WrappedFun) – 要应用变换的 WrappedFun

  • gen_static_args – 生成器函数的静态参数

返回类型:

WrappedFun