jax.pure_callback#

jax.pure_callback(callback, result_shape_dtypes, *args, sharding=None, vectorized=Deprecated, vmap_method=None, **kwargs)[源代码]#

调用纯 Python 回调。在 jit()/vmap()/等等下工作。

有关更多说明,请参阅外部回调

pure_callback 允许在 JIT 编译的 JAX 函数中调用 Python 函数。输入的 callback 将被传递放置在本地 CPU 上的 JAX 数组,并且它还应该在 CPU 上返回 JAX 数组。

该回调被视为功能纯函数,意味着它没有副作用,其输出值仅取决于其参数值。因此,可以安全地多次调用它(例如,当被 vmap()pmap() 转换时),或者当例如一个由 jit 修饰的函数的输出与其值没有数据依赖关系时,可以根本不调用它。如果数据依赖性允许,纯回调也可以重新排序。

当使用 vmap 时,行为将取决于 vmap_method 的值。

  • 在没有显式 vmap_method 的情况下对回调调用 vmap() 已被弃用,最终会引发 NotImplementedError

  • vmap_method="sequential" 使用 map() 循环遍历批处理的参数,为每个批处理元素调用一次 callback

  • vmap_method="expand_dims" 在未批处理的输入的前导维度上添加大小为 1 的新轴,并调用 callback

  • vmap_method="broadcast_all" 的行为类似于 expand_dims,但输入会被平铺到预期的批处理形状。

如有必要,可以使用 vmap_method="legacy_vectorized" 恢复已弃用的 vectorized=True 参数提供的旧行为。

当前默认行为是在未指定时使用 vmap_method="sequential",但此行为已弃用,将来,除非显式指定 vmap_method,否则默认行为将引发 NotImplementedError

参数:
  • callback (Callable[..., Any]) – 在主机上执行的函数。假定回调是一个纯函数(即没有副作用的函数):如果传递一个不纯的函数,它可能会以意想不到的方式运行,尤其是在转换下。该可调用对象将传递数组的 PyTree 作为参数,并应返回一个与 result_shape_dtypes 匹配的数组 PyTree。

  • result_shape_dtypes (Any) – 其叶子具有 shapedtype 属性的 pytree,其结构与运行时回调函数的预期输出匹配。 jax.ShapeDtypeStruct 通常用于定义叶子值。

  • *args (Any) – 要传递给回调函数的参数

  • sharding (SingleDeviceSharding | None | None) – 可选的分片,用于指定应从中调用回调的设备。

  • vmap_method (str | None | None) – 字符串,指定回调在 vmap() 下如何转换,如上所述。

  • **kwargs (Any) – 要传递给回调函数的关键字参数

  • vectorized (bool | None | DeprecatedArg)

返回:

一个 jax.Array 对象的 pytree,其结构与

result_shape_dtypes 的结构相匹配.

返回类型:

结果

另请参阅

示例

如上所述,pure_callbackvmap() 下的行为由 vmap_method 参数控制。考虑一些明确的示例来演示语义是很有用的。例如,考虑以下函数

>>> def callback(x, y):
...   print(jnp.shape(x), jnp.shape(y))
...   return x + y
>>> def fun(x, y, *, vmap_method):
...   shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y))
...   dtype = jnp.result_type(x, y)
...   out_type = jax.ShapeDtypeStruct(shape, dtype)
...   return jax.pure_callback(callback, out_type, x, y,
...                            vmap_method=vmap_method)

使用 vmap_method="expand_dims" 调用此函数会在 y 中添加一个大小为 1 的新轴

>>> from functools import partial
>>> x = jnp.arange(4)
>>> y = 1.0
>>> jax.vmap(partial(fun, vmap_method="expand_dims"), in_axes=(0, None))(x, y)
(4,) (1,)
Array([1., 2., 3., 4.], dtype=float32)

而,vmap_method="broadcast_all"y 中添加一个大小为 4 的轴

>>> jax.vmap(partial(fun, vmap_method="broadcast_all"),
...          in_axes=(0, None))(x, y)
(4,) (4,)
Array([1., 2., 3., 4.], dtype=float32)