jax.pure_callback#
- jax.pure_callback(callback, result_shape_dtypes, *args, sharding=None, vectorized=False, **kwargs)[source]#
调用纯 Python 回调。在
jit()
/vmap()
/等等下工作。有关更多解释,请参阅 外部回调.
pure_callback
允许在 JIT 化的 JAX 函数中调用 Python 函数。输入callback
将传递放置在本地 CPU 上的 JAX 数组,它也应该返回 CPU 上的 JAX 数组。回调函数被视为函数式纯函数,这意味着它没有副作用,其输出值仅取决于其参数值。因此,它可以安全地被多次调用(例如,当通过
vmap()
或pmap()
进行变换时),或者当例如 jit 装饰的函数的输出与其值没有数据依赖性时,完全不调用。如果数据依赖性允许,纯回调函数也可以重新排序。当 vmap 时,行为将取决于
vectorized
关键字参数的值。当vectorized
为True
时,假定回调函数服从jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])
。因此,回调函数将直接在批处理输入上调用(其中批处理轴是前导维度)。此外,回调函数应返回具有相应前导批处理轴的输出。如果未向量化,callback
将在批处理轴上按顺序映射。例如,如果callback = lambda x, y: np.matmul(x, y)
,那么我们可以自由设置vectorized=True
,因为np.matmul
函数处理任意前导批处理维度。- 参数:
callback (Callable[..., Any]) – 在主机上执行的函数。假定回调函数是纯函数(即没有副作用):如果传递了一个非纯函数,它可能会以意想不到的方式运行,尤其是在转换下。该可调用对象将被传递数组的 PyTrees 作为参数,并应返回与
result_shape_dtypes
匹配的数组的 PyTree。result_shape_dtypes (Any) – 其叶子具有
shape
和dtype
属性的 pytree,其结构与运行时回调函数的预期输出匹配。jax.ShapeDtypeStruct
通常用于定义叶子值。*args (Any) – 要传递给回调函数的参数
sharding (SingleDeviceSharding | None | None) – 可选的分片,指定应调用回调函数的设备。
vectorized (bool) – 布尔值,指定回调函数是否可以以向量化方式运行。
**kwargs (Any) – 要传递给回调函数的关键字参数
- 返回:
- 一个 pytree 的
jax.Array
对象,其结构与 result_shape_dtypes
.
- 一个 pytree 的
- 返回类型:
结果
另请参阅
jax.experimental.io_callback()
: 为非纯函数设计的回调函数。jax.debug.callback()
: 为通用调试设计的回调函数。jax.debug.print()
: 为打印设计的回调函数。