jax.experimental.io_callback#
- jax.experimental.io_callback(callback, result_shape_dtypes, *args, sharding=None, ordered=False, **kwargs)[源]#
调用一个不纯的 Python 回调函数。
更多解释,请参阅外部回调。
- 参数:
callback (Callable[..., Any]) – 在主机上执行的函数。 假设它是一个不纯的函数。 如果
callback
是纯函数,使用jax.pure_callback()
可能会带来更高效的执行。result_shape_dtypes (Any) – pytree,其叶子具有
shape
和dtype
属性,其结构与运行时回调函数的预期输出相匹配。jax.ShapeDtypeStruct
通常用于定义叶子值。*args (Any) – 要传递给回调函数的参数
sharding (SingleDeviceSharding | None | None) – 可选的分片,指定应从中调用回调的设备。
ordered (bool) – 布尔值,指定对回调的顺序调用是否必须排序。
**kwargs (Any) – 要传递给回调函数的关键字参数
- 返回:
- 一个
jax.Array
对象的 pytree,其结构与 result_shape_dtypes 的结构相匹配
.
- 一个
- 返回类型:
result
另请参阅
jax.pure_callback()
:为纯函数设计的回调函数。jax.debug.callback()
:为通用调试设计的回调函数。jax.debug.print()
:为打印设计的回调函数。