jax.experimental.host_callback.id_tap

目录

jax.experimental.host_callback.id_tap#

jax.experimental.host_callback.id_tap(tap_func, arg, *, result=None, tap_with_device=False, device_index=0, callback_flavor=CallbackFlavor.IO_CALLBACK, **kwargs)[source]#

主机回调 tap 原语,类似于带有对 tap_func 的调用的标识函数。

警告

截至 2024 年 3 月 20 日,host_callback API 已被弃用。该功能被 新的 JAX 外部回调 所取代。请参见 google/jax#20385

id_tap 在语义上类似于标识函数,但它有一个副作用,即会调用用户定义的 Python 函数,并将参数的运行时值传递给它。

参数:
  • tap_func – 类似于 tap_func(arg, transforms) 的调用 tap 函数,其中 arg 如以下所述,transforms 是应用于 JAX 变换的序列,形式为 (name, params)。如果可选参数 tap_with_device 为 True,则调用还包含从其获取值的设备,作为关键字参数:tap_func(arg, transforms, device=dev)

  • arg – 传递给 tap 函数的参数,可以是 JAX 类型树。

  • result – 如果给定,则指定 id_tap 的返回值。此值不会传递给 tap 函数,实际上不会从设备发送到主机。如果未指定 result 参数,则 id_tap 的返回值为 arg

  • tap_with_device – 如果为 True,则 tap 函数将使用获取 tap 的设备作为关键字参数进行调用。

  • device_index – 指定在 SPMD 程序中调用 tap 函数的设备。仅在使用 outfeed 实现机制时有效,即在 CPU 上无效,除非 –jax_host_callback_outfeed=True。

  • callback_flavor – 如果使用 JAX_HOST_CALLBACK_LEGACY=False 运行,则指定要使用的回调类型。参见 google/jax#20385.

返回值:

argresult(如果给定)。

执行顺序由数据依赖性决定:在所有参数和 result 的值(如果存在)计算完毕后,以及使用返回值之前。id_tap 的至少一个返回值必须在剩余计算中使用,否则此操作将无效。

即使在加速器上执行的代码,以及在 JAX 变换下的代码,tap 也能正常工作。

有关更多详细信息,请参见 jax.experimental.host_callback 模块文档。