jax.extend.linear_util.WrappedFun

jax.extend.linear_util.WrappedFun#

class jax.extend.linear_util.WrappedFun(f, transforms, stores, params, in_type, debug_info)[source]#

表示一个函数 f,对其将应用 transforms

参数:
  • f – 要变换的函数。

  • transforms – 一个 (gen, gen_static_args) 元组列表,表示要应用于 f 的变换。 这里 gen 是一个生成器函数,而 gen_static_args 是一个生成器静态参数的元组。有关生成器的预期行为,请参阅本模块开头的描述。

  • storestransforms 辅助输出的 out_store 列表。

  • params – 要传递给 f 的额外参数,作为关键字参数,以及经过变换的关键字参数。

__init__(f, transforms, stores, params, in_type, debug_info)[source]#

方法

__init__(f, transforms, stores, params, ...)

call_wrapped(*args, **kwargs)

调用底层函数,应用转换。

populate_stores(stores)

stores中的值复制到self.stores中。

wrap(gen, gen_static_args, out_store)

添加另一个转换及其存储。

属性

f

transforms

stores

params

in_type

debug_info