jax.experimental.host_callback.id_tap#
- jax.experimental.host_callback.id_tap(tap_func, arg, *, result=None, tap_with_device=False, device_index=0, callback_flavor=CallbackFlavor.IO_CALLBACK, **kwargs)[source]#
主机回调 tap 原语,类似于带有对
tap_func
的调用的标识函数。警告
截至 2024 年 3 月 20 日,host_callback API 已被弃用。该功能被 新的 JAX 外部回调 所取代。请参见 google/jax#20385。
id_tap
在语义上类似于标识函数,但它有一个副作用,即会调用用户定义的 Python 函数,并将参数的运行时值传递给它。- 参数:
tap_func – 类似于
tap_func(arg, transforms)
的调用 tap 函数,其中arg
如以下所述,transforms
是应用于 JAX 变换的序列,形式为(name, params)
。如果可选参数 tap_with_device 为 True,则调用还包含从其获取值的设备,作为关键字参数:tap_func(arg, transforms, device=dev)
。arg – 传递给 tap 函数的参数,可以是 JAX 类型树。
result – 如果给定,则指定
id_tap
的返回值。此值不会传递给 tap 函数,实际上不会从设备发送到主机。如果未指定result
参数,则id_tap
的返回值为arg
。tap_with_device – 如果为 True,则 tap 函数将使用获取 tap 的设备作为关键字参数进行调用。
device_index – 指定在 SPMD 程序中调用 tap 函数的设备。仅在使用 outfeed 实现机制时有效,即在 CPU 上无效,除非 –jax_host_callback_outfeed=True。
callback_flavor – 如果使用 JAX_HOST_CALLBACK_LEGACY=False 运行,则指定要使用的回调类型。参见 google/jax#20385.
- 返回值:
arg
或result
(如果给定)。
执行顺序由数据依赖性决定:在所有参数和
result
的值(如果存在)计算完毕后,以及使用返回值之前。id_tap
的至少一个返回值必须在剩余计算中使用,否则此操作将无效。即使在加速器上执行的代码,以及在 JAX 变换下的代码,tap 也能正常工作。
有关更多详细信息,请参见
jax.experimental.host_callback
模块文档。