jax.tree_util.Partial

jax.tree_util.Partial#

class jax.tree_util.Partial(func, *args, **kw)#

一个与 pytree 兼容的 functools.partial 版本。

使用它进行部分函数评估,这种方法与 JAX 的转换兼容,例如 Partial(func, *args, **kwargs)

(您需要明确选择此行为,因为我们不想赋予 functools.partial 与普通函数闭包不同的语义。)

例如,以下是如何以类似于 functools.partial 的方式使用 Partial 的基本示例

>>> import jax.numpy as jnp
>>> add_one = Partial(jnp.add, 1)
>>> add_one(2)
Array(3, dtype=int32, weak_type=True)

Pytree 兼容性意味着生成的偏函数可以作为参数传递给已转换的 JAX 函数,而标准的 functools.partial 函数则无法做到这一点

>>> from jax import jit
>>> @jit
... def call_func(f, *args):
...   return f(*args)
...
>>> call_func(add_one, 2)
Array(3, dtype=int32, weak_type=True)

Partial 传递零个参数实际上会包装原始函数,使其成为 JAX 已转换函数中的有效参数

>>> call_func(Partial(jnp.add), 1, 2)
Array(3, dtype=int32, weak_type=True)

如果我们将 jnp.add 直接传递给 call_func,将会导致 TypeError

请注意,如果 Partial 的结果在追踪值的上下文中使用,则会导致所有绑定的参数在传递给部分评估的函数时被追踪。

>>> print_zero = Partial(print, 0)
>>> print_zero()
0
>>> call_func(print_zero)  
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace...>
__init__()#

方法

属性

args

传递给未来部分调用函数的参数元组

func

未来部分调用中要使用的函数对象

keywords

传递给未来部分调用函数的关键字参数字典