jax.pure_callback#
- jax.pure_callback(callback, result_shape_dtypes, *args, sharding=None, vectorized=Deprecated, vmap_method=None, **kwargs)[源代码]#
调用纯 Python 回调。在
jit()
/vmap()
/等等下工作。有关更多说明,请参阅外部回调。
pure_callback
允许在 JIT 编译的 JAX 函数中调用 Python 函数。输入的callback
将被传递放置在本地 CPU 上的 JAX 数组,并且它还应该在 CPU 上返回 JAX 数组。该回调被视为功能纯函数,意味着它没有副作用,其输出值仅取决于其参数值。因此,可以安全地多次调用它(例如,当被
vmap()
或pmap()
转换时),或者当例如一个由 jit 修饰的函数的输出与其值没有数据依赖关系时,可以根本不调用它。如果数据依赖性允许,纯回调也可以重新排序。当使用 vmap 时,行为将取决于
vmap_method
的值。在没有显式
vmap_method
的情况下对回调调用vmap()
已被弃用,最终会引发NotImplementedError
。vmap_method="sequential"
使用map()
循环遍历批处理的参数,为每个批处理元素调用一次callback
。vmap_method="expand_dims"
在未批处理的输入的前导维度上添加大小为1
的新轴,并调用callback
。vmap_method="broadcast_all"
的行为类似于expand_dims
,但输入会被平铺到预期的批处理形状。
如有必要,可以使用
vmap_method="legacy_vectorized"
恢复已弃用的vectorized=True
参数提供的旧行为。当前默认行为是在未指定时使用
vmap_method="sequential"
,但此行为已弃用,将来,除非显式指定vmap_method
,否则默认行为将引发NotImplementedError
。- 参数:
callback (Callable[..., Any]) – 在主机上执行的函数。假定回调是一个纯函数(即没有副作用的函数):如果传递一个不纯的函数,它可能会以意想不到的方式运行,尤其是在转换下。该可调用对象将传递数组的 PyTree 作为参数,并应返回一个与
result_shape_dtypes
匹配的数组 PyTree。result_shape_dtypes (Any) – 其叶子具有
shape
和dtype
属性的 pytree,其结构与运行时回调函数的预期输出匹配。jax.ShapeDtypeStruct
通常用于定义叶子值。*args (Any) – 要传递给回调函数的参数
sharding (SingleDeviceSharding | None | None) – 可选的分片,用于指定应从中调用回调的设备。
vmap_method (str | None | None) – 字符串,指定回调在
vmap()
下如何转换,如上所述。**kwargs (Any) – 要传递给回调函数的关键字参数
vectorized (bool | None | DeprecatedArg)
- 返回:
- 一个
jax.Array
对象的 pytree,其结构与 result_shape_dtypes 的结构相匹配
.
- 一个
- 返回类型:
结果
另请参阅
jax.experimental.io_callback()
:专为不纯函数设计的回调。jax.debug.callback()
:专为通用调试设计的回调。jax.debug.print()
:专为打印设计的回调。
示例
如上所述,
pure_callback
在vmap()
下的行为由vmap_method
参数控制。考虑一些明确的示例来演示语义是很有用的。例如,考虑以下函数>>> def callback(x, y): ... print(jnp.shape(x), jnp.shape(y)) ... return x + y
>>> def fun(x, y, *, vmap_method): ... shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y)) ... dtype = jnp.result_type(x, y) ... out_type = jax.ShapeDtypeStruct(shape, dtype) ... return jax.pure_callback(callback, out_type, x, y, ... vmap_method=vmap_method)
使用
vmap_method="expand_dims"
调用此函数会在y
中添加一个大小为1
的新轴>>> from functools import partial >>> x = jnp.arange(4) >>> y = 1.0 >>> jax.vmap(partial(fun, vmap_method="expand_dims"), in_axes=(0, None))(x, y) (4,) (1,) Array([1., 2., 3., 4.], dtype=float32)
而,
vmap_method="broadcast_all"
在y
中添加一个大小为4
的轴>>> jax.vmap(partial(fun, vmap_method="broadcast_all"), ... in_axes=(0, None))(x, y) (4,) (4,) Array([1., 2., 3., 4.], dtype=float32)