jax.experimental.host_callback.call

目录

jax.experimental.host_callback.call#

jax.experimental.host_callback.call(callback_func, arg, *, result_shape=None, call_with_device=False, device_index=0, callback_flavor=CallbackFlavor.IO_CALLBACK)[source]#

调用主机,并期望得到结果。

警告

从 2024 年 3 月 20 日起,host_callback API 已被弃用。此功能已被 新的 JAX 外部回调 取代。请参阅 google/jax#20385

参数::
  • callback_func (Callable) – 要在主机上调用的 Python 函数,以 callback_func(arg) 的形式。如果可选参数 call_with_device 为 True,则调用还包括带有调用来源设备的 device 关键字参数:callback_func(arg, device=dev)。此函数必须返回一个由 numpy 多维数组组成的 pytree。

  • arg – 传递给回调函数的参数,可以是 JAX 类型组成的 pytree。

  • result_shape – 描述结果预期形状和数据类型的值。这可以是一个数字标量,从中获取形状和数据类型,也可以是一个具有 .shape.dtype 属性的对象。如果回调的结果是 pytree,则 result_shape 也应该是一个具有相同结构的 pytree。特别是,如果函数没有结果,result_shape 可以是 ()None。包含 call 的设备代码使用预期的结果形状和数据类型进行编译,如果实际的 callback_func 调用返回不同类型的结果,将在运行时引发错误。

  • call_with_device – 如果为 True,则使用调用来源的设备作为关键字参数调用回调函数。

  • device_index – 指定在 SPMD 程序中从哪个设备调用 tap 函数。仅在使用 outfeed 实现机制时有效,即,除非 –jax_host_callback_outfeed=True,否则在 CPU 上无效。

  • callback_flavor – 如果使用 JAX_HOST_CALLBACK_LEGACY=False 运行,则指定要使用的回调类型。请参阅 google/jax#20385

返回值:

callback_func 调用的结果。

有关更多详细信息,请参阅 jax.experimental.host_callback 模块文档。