jax.experimental.io_callback#
- jax.experimental.io_callback(callback, result_shape_dtypes, *args, sharding=None, ordered=False, **kwargs)[source]#
调用一个不纯的 Python 回调。
有关更多解释,请参阅 外部回调.
- 参数::
callback (Callable[..., Any]) – 在主机上执行的函数。假设它是一个不纯函数。如果
callback
是纯函数,使用jax.pure_callback()
可能会导致更有效的执行。result_shape_dtypes (任何类型) – 叶节点具有
shape
和dtype
属性的 Pytree,其结构与回调函数在运行时预期的输出相匹配。jax.ShapeDtypeStruct
通常用于定义叶节点值。*args (任何类型) – 要传递给回调函数的参数
sharding (单设备分片 | 无 | 无) – 指定应从哪个设备调用回调函数的可选分片。
ordered (布尔值) – 指定回调函数的顺序调用是否必须有序的布尔值。
**kwargs (任何类型) – 要传递给回调函数的关键字参数
- 返回值:
- 一个
jax.Array
对象的 Pytree,其结构与 result_shape_dtypes
.
- 一个
- 返回类型:
result
另请参见
jax.pure_callback()
: 为纯函数设计的回调函数。jax.debug.callback()
: 为通用调试设计的回调函数。jax.debug.print()
: 为打印设计的回调函数。