jax.experimental.io_callback

目录

jax.experimental.io_callback#

jax.experimental.io_callback(callback, result_shape_dtypes, *args, sharding=None, ordered=False, **kwargs)[source]#

调用一个不纯的 Python 回调。

有关更多解释,请参阅 外部回调.

参数::
  • callback (Callable[..., Any]) – 在主机上执行的函数。假设它是一个不纯函数。如果 callback 是纯函数,使用 jax.pure_callback() 可能会导致更有效的执行。

  • result_shape_dtypes (任何类型) – 叶节点具有 shapedtype 属性的 Pytree,其结构与回调函数在运行时预期的输出相匹配。 jax.ShapeDtypeStruct 通常用于定义叶节点值。

  • *args (任何类型) – 要传递给回调函数的参数

  • sharding (单设备分片 | | ) – 指定应从哪个设备调用回调函数的可选分片。

  • ordered (布尔值) – 指定回调函数的顺序调用是否必须有序的布尔值。

  • **kwargs (任何类型) – 要传递给回调函数的关键字参数

返回值:

一个 jax.Array 对象的 Pytree,其结构与

result_shape_dtypes.

返回类型:

result

另请参见