jax.pure_callback

jax.pure_callback#

jax.pure_callback(callback, result_shape_dtypes, *args, sharding=None, vectorized=False, **kwargs)[source]#

调用纯 Python 回调。在 jit()/vmap()/等等下工作。

有关更多解释,请参阅 外部回调.

pure_callback 允许在 JIT 化的 JAX 函数中调用 Python 函数。输入 callback 将传递放置在本地 CPU 上的 JAX 数组,它也应该返回 CPU 上的 JAX 数组。

回调函数被视为函数式纯函数,这意味着它没有副作用,其输出值仅取决于其参数值。因此,它可以安全地被多次调用(例如,当通过 vmap()pmap() 进行变换时),或者当例如 jit 装饰的函数的输出与其值没有数据依赖性时,完全不调用。如果数据依赖性允许,纯回调函数也可以重新排序。

vmap 时,行为将取决于 vectorized 关键字参数的值。当 vectorizedTrue 时,假定回调函数服从 jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])。因此,回调函数将直接在批处理输入上调用(其中批处理轴是前导维度)。此外,回调函数应返回具有相应前导批处理轴的输出。如果未向量化,callback 将在批处理轴上按顺序映射。例如,如果 callback = lambda x, y: np.matmul(x, y),那么我们可以自由设置 vectorized=True,因为 np.matmul 函数处理任意前导批处理维度。

参数:
  • callback (Callable[..., Any]) – 在主机上执行的函数。假定回调函数是纯函数(即没有副作用):如果传递了一个非纯函数,它可能会以意想不到的方式运行,尤其是在转换下。该可调用对象将被传递数组的 PyTrees 作为参数,并应返回与 result_shape_dtypes 匹配的数组的 PyTree。

  • result_shape_dtypes (Any) – 其叶子具有 shapedtype 属性的 pytree,其结构与运行时回调函数的预期输出匹配。 jax.ShapeDtypeStruct 通常用于定义叶子值。

  • *args (Any) – 要传递给回调函数的参数

  • sharding (SingleDeviceSharding | None | None) – 可选的分片,指定应调用回调函数的设备。

  • vectorized (bool) – 布尔值,指定回调函数是否可以以向量化方式运行。

  • **kwargs (Any) – 要传递给回调函数的关键字参数

返回:

一个 pytree 的 jax.Array 对象,其结构与

result_shape_dtypes.

返回类型:

结果

另请参阅