jax.extend.linear_util.cache#

jax.extend.linear_util.cache(call, *, explain=None)[源代码]#

为以 WrappedFun 作为第一个参数的函数提供的记忆化装饰器。

参数:
  • call (Callable) – 一个 Python 可调用对象,它以 WrappedFun 作为其第一个参数。 WrappedFun 上的底层变换和参数用作记忆化缓存键的一部分。

  • explain (Callable | None | None)

返回:

call 的记忆化版本。