jax.experimental.host_callback 模块#

用于从 JAX 加速器代码调用主机上的 Python 函数的原语。

警告

从 2024 年 3 月 20 日起,host_callback API 已弃用。此功能已被 新的 JAX 外部回调 替代。请参阅 google/jax#20385

本模块介绍了主机回调函数call()id_tap()id_print(),这些函数将它们的实参从设备发送到主机,并在主机上调用用户定义的 Python 函数,可以选择将结果返回到设备计算中。

下面我们将展示如何使用这些函数。我们从call()开始,并讨论从 JAX 调用 CPU 上任意 Python 函数的示例,例如,使用 NumPy CPU 自定义内核。然后我们将展示id_tap()id_print()的使用方法,它们受到限制,不能将值从主机返回到设备。这些原语通常更快,因为它们与设备代码异步执行。特别是,它们可以用于访问和调试 JAX 代码。

使用call()调用主机函数并将结果返回到设备#

使用call()在主机上调用计算并将 NumPy 数组返回到设备计算中。主机计算很有用,例如,当设备计算需要一些需要在主机上进行 I/O 的数据时,或者它需要一个在主机上可用的库,而您不想在 JAX 中对其进行编码时。例如,JAX 中一般矩阵的特征值分解在 TPU 上不起作用。我们可以使用主机计算从任何 JAX 加速器计算中调用 Numpy 实现。

# This function runs on the host
def host_eig(m: np.ndarray) -> np.ndarray:
  return np.linalg.eigvals(m)

# This function is used in JAX
def device_fun(m):
  # We send "m" to the host, asking it to call "host_eig" and return the result.
  # We have to specify the result shape and dtype, either in the form of an
  # example return value or any object that has `shape` and `dtype` attributes,
  # e.g., a NumPy array or a `jax.ShapeDtypeStruct`.
  return hcb.call(host_eig, m,
                  # Given an input of shape (..., d, d), eig output has shape (..., d)
                  result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype))

call()函数和 Python 主机函数都接受一个实参并返回一个结果,但这些可以是 pytrees。请注意,我们必须使用result_shape关键字实参告诉call()期望主机调用的形状和数据类型。这很重要,因为设备代码是根据此期望进行编译的。如果实际调用产生不同的结果形状,则将在运行时引发错误。一般来说,**此类错误以及主机计算引发的异常可能难以调试**。请参阅下面的调试部分。这是call()的问题,但不是id_tap()的问题,因为对于后者,设备代码不期望返回值。

call()API 可以用在 jit 或 pmap 计算中,或用在 cond/scan/while 控制流中。当用在jax.pmap()内部时,将从每个参与的设备分别调用主机。

def host_sin(x, *, device):
  # The ``device`` argument is passed due to ``call_with_device=True`` below.
  print(f"Invoking host_sin with {x.shape} on {device}")
  return np.sin(x)

# Use pmap to run the computation on two devices
jax.pmap(lambda x: hcb.call(host_sin, x,
                            result_shape=x,
                            # Ask that the `host_sin` function be passed `device=dev`
                            call_with_device=True))(
         np.ones((2, 4), dtype=np.float32))

# prints (in arbitrary order)
# Invoking host_sin with (4,) on cpu:0
# Invoking host_sin with (4,) on cpu:1

请注意,call()不支持任何 JAX 变换,但如下所示,我们可以利用现有的JAX 中自定义微分支持。

使用id_tap()调用主机上的 Python 函数,不返回任何值#

id_tap()id_print()call()的特殊情况,当您只需要 Python 回调的副作用时。这些函数的优点是,一旦实参已发送到主机,设备计算就可以继续进行,而无需等待 Python 回调返回。对于id_tap(),您可以指定要调用的 Python 回调,而id_print()使用一个内置回调,该回调将实参打印到主机上的stdout。传递给id_tap()的 Python 函数接受两个位置实参(从设备计算中获取的值以及下面描述的transforms元组)。可选地,可以向函数传递一个关键字实参device,其中包含从中获取值的设备。

一些示例

def host_func(arg, transforms):
   ...do something with arg...

# calls host_func(2x, []) on host
id_tap(host_func, 2 * x)

# calls host_func((2x, 3x), [])
id_tap(host_func, (2 * x, 3 * x))  # The argument can be a pytree

# calls host_func(2x, [], device=jax.devices()[0])
id_tap(host_func, 2 * x, tap_with_device=True)  # Pass the device to the tap

# calls host_func(2x, [], what='activation')
id_tap(functools.partial(host_func, what='activation'), 2 * x)

# calls host_func(dict(x=x, y=y), what='data')
id_tap(lambda tap, transforms: host_func(tap, what='data'), dict(x=x, y=y))

上述所有示例都可以改编为使用id_print(),不同之处在于id_print()将在主机上打印位置实参,以及任何其他 kwargs 和自动 kwarg transforms

使用barrier_wait()等待所有回调执行完成#

如果您的 Python 回调具有副作用,您可能需要等到计算完成才能确保已观察到副作用。您可以为此目的使用barrier_wait()函数。

accumulator = []
def host_log(arg, transforms):
  # We just record the arguments in a list
  accumulator.append(arg)


def device_fun(x):
  id_tap(host_log, x)
  id_tap(host_log, 2. * x)

jax.jit(device_fun)(1.)
jax.jit(device_fun)(1.)

# At this point, we have started two computations, each with two
# taps, but they may not have yet executed.
barrier_wait()
# Now we know that all the computations started before `barrier_wait`
# on all devices, have finished, and all the callbacks have finished
# executing.

请注意,barrier_wait()将在每个jax.local_devices()上启动一个带有单个轻触的小型计算,并将等待接收所有这些轻触。

如果所有回调都是call(),则使用barrier_wait()的替代方法是等待计算结束。

accumulator = p[]
def host_log(arg):
  # We just record the arguments in a list
  accumulator.append(arg)
  return 0.  #  return something


def device_fun(c):
  y = call(host_log, x, result_shape=jax.ShapeDtypeStruct((), np.float32))
  z = call(host_log, 2. * x, result_shape=jax.ShapeDtypeStruct((), np.float32))
  return y + z  # return something that uses both results

res1 = jax.jit(device_fun)(1.)
res2 = jax.jit(device_fun)(1.)
res1.block_until_ready()
res2.block_until_ready()

并行化变换下的行为#

在存在jax.pmap()的情况下,代码将在多个设备上运行,并且每个设备将独立轻触其值。对于id_print()id_tap(),使用tap_with_device选项可能会有所帮助,以便您可以查看哪个设备发送了哪些数据。

jax.pmap(power3, devices=jax.local_devices()[:2])(np.array([3., 4.])
# device=cpu:0 what=x,x^2: (3., 9.)  # from the first device
# device=cpu:1 what=x,x^2: (4., 16.)  # from the second device

当在多个主机上的多个设备上使用jax.pmap()时,每个主机都将收到来自其所有本地设备的回调,并带有与每个设备切片对应的操作数。对于call(),回调必须仅向每个设备返回与其对应的设备切片。

当使用实验性的pjit.pjit()时,代码将在输入的不同分片上的多个设备上运行。主机回调的当前实现将确保单个设备将收集并输出整个操作数,在一个回调中。回调函数应该返回整个数组,然后将其作为单个输入发送到发出输出的同一设备。然后,此设备负责将所需的分片发送到其他设备。

with jax.sharding.Mesh(jax.local_devices()[:2], ["d"]):
  pjit.pjit(power3, in_shardings=(P("d"),),
            out_shardings=(P("d"),))(np.array([3., 4.]))

# device=TPU:0 what=x,x^2: ( [3., 4.],
#                            [9., 16.] )

请注意,如果操作数已跨设备分片,则在一个设备上收集操作数可能会导致 OOM。

当在多个主机上的多个设备上使用pjit.pjit()时,只有设备 0(相对于网格)的主机将收到回调,操作数将从所有主机上的所有参与设备收集。对于call(),回调必须返回所有主机上所有设备的整个数组。

JAX 自动微分变换下的行为#

当在 JAX 自动微分变换下使用时,主机回调函数仅对原始值进行操作。请考虑以下示例

def power3(x):
  y = x * x
  # Print both 'x' and 'x^2'. Must pack as a tuple.
  hcb.id_print((x, y), what="x,x^2")
  return y * x

power3(3.)
# what: x,x^2 : (3., 9.)

(您可以在 host_callback_test.HostCallbackTapTest.test_tap_transforms 中看到这些示例的测试结果。)

当在 jax.jvp() 下使用时,将只会有一个包含原始值的回调

jax.jvp(power3, (3.,), (0.1,))
# what: x,x^2 : (3., 9.)

类似地,对于 jax.grad(),我们只从前向计算中获得一个回调

jax.grad(power3)(3.)
# what: x,x^2 : (3., 9.)

如果您想在 jax.jvp() 期间对切线调用回调,可以使用 custom_jvp。例如,您可以定义一个除了其 custom_jvp 将打印切线之外什么都不做的函数

@jax.custom_jvp
def print_tangents(arg):
  return None

@print_tangents.defjvp
def print_tangents_jvp(primals, tangents):
  arg_dot, = tangents
  hcb.id_print(arg_dot, what="tangents")
  return primals, tangents

然后,您可以在想要截取切线的位置使用此函数

def power3_with_tangents(x):
  y = x * x
  # Print both 'x' and 'x^2'. Must pack as a tuple.
  hcb.id_print((x, y), what="x,x^2")
  print_tangents((x, y))
  return y * x

jax.jvp(power3_with_tangents, (3.,), (0.1,))
# what: x,x^2 : (3., 9.)
# what: tangents : (0.1, 0.6)

对于 jax.grad() 期间的余切,您可以执行类似的操作。这次您必须小心地在计算的其余部分使用您想要截取其余切的值。因此,我们使 print_cotangents 返回其参数

@jax.custom_vjp
def print_cotangents(arg):
  # Must return the argument for which we want the cotangent.
  return arg

# f_fwd: a -> (b, residual)
def print_cotangents_fwd(arg):
  return print_cotangents(arg), None
# f_bwd: (residual, CT b) -> [CT a]
def print_cotangents_bwd(residual, ct_b):
  hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream)
  return ct_b,

print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd)

def power3_with_cotangents(x):
  y = x * x
  # Print both 'x' and 'x^2'. Must pack as a tuple.
  hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
  (x1, y1) = print_cotangents((x, y))
  # Must use the output of print_cotangents
  return y1 * x1

jax.grad(power3_with_cotangents)(3.)
# what: x,x^2 : (3., 9.)
# what: cotangents : (9., 3.)

如果您使用 ad_checkpoint.checkpoint() 重新计算反向传播的残差,则来自原始计算的回调将被调用两次

jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)
# what: x,x^2 : (3., 9.)
# what: x,x^2 : (27., 729.)
# what: x,x^2 : (3., 9.)

回调的顺序依次为:内部 power3 的原始计算、外部 power3 的原始计算以及内部 power3 的残差的重新计算。

jax.vmap 下的行为#

主机回调函数 id_print()id_tap() 支持向量化变换 jax.vmap()

对于 jax.vmap(),回调的参数是批处理的,并且回调函数会传递一个额外的特殊 transforms,其中包含一个变换描述符列表,格式为 ("batch", {"batch_dims": ...}),其中 ...` 表示被截取值的批处理维度(每个参数一个条目,` None` 表示广播的参数)。

jax.vmap(power3)(np.array([2., 3.])) # transforms: [(‘batch’, {‘batch_dims’: (0, 0)})] what: x,x^2 : ([2., 3.], [4., 9.])

请参阅 id_tap()id_print()call() 的文档。

有关更多用法示例,请参阅 tests/host_callback_test.py。

使用 call() 调用 TensorFlow 函数,支持反向模式自动微分#

主机计算的另一个可能用途是调用为其他框架(如 TensorFlow)编写的库。在这种情况下,通过使用 jax.custom_vjp() 机制将自动微分机制委托给 TensorFlow,从而支持主机回调的 JAX 自动微分变得很有意义。

一旦理解了 JAX 自定义 VJP 和 TensorFlow 自动微分机制,这相对容易做到。host_callback_to_tf_test.py 中的 call_tf_full_ad 函数展示了如何实现这一点。此示例还支持任意高阶微分。

请注意,如果您只想从 JAX 调用 TensorFlow 函数,还可以使用 jax2tf.call_tf 函数

使用 call() 在另一个设备上调用 JAX 函数,支持反向模式自动微分#

我们使用主机计算在另一个设备上调用 JAX 计算并不奇怪。参数从加速器发送到主机,然后发送到 JAX 主机计算将在其上运行的外部设备,然后结果被发送回原始加速器。

host_callback_test.py 中的 call_jax_other_device function 函数展示了如何实现这一点。

底层细节和调试#

主机回调函数将为每个设备按设备上发送操作执行的顺序执行。

多个设备的主机回调函数可能会交错执行。来自设备的数据由 JAX 运行时管理的单独线程接收(每个设备一个线程)。运行时维护一个大小可配置的缓冲区(请参阅标志 --jax_host_callback_max_queue_byte_size)。当缓冲区已满时,所有接收线程都会暂停,这最终会导致设备上的计算暂停。运行时为每个设备还有一个额外的线程,用于使用接收到的数据调用 Python 用户函数。如果回调的处理速度很慢,实际上可能会导致运行时缓冲区填满,并最终在设备需要发送某些内容时暂停设备上的计算。有关出站接收器运行时机制的更多详细信息,请参阅 运行时代码

为了暂停执行,直到设备上已启动的计算的所有数据都已到达并已处理,请使用 barrier_wait()

来自用户定义回调函数的异常将与它们的堆栈跟踪一起记录,但接收线程不会停止。相反,会记录最后一个异常,并且随后的 barrier_wait() 将引发 CallbackException(如果在某个 tap 函数中发生了任何异常)。此异常将包含遇到的最后一个异常的文本和堆栈跟踪。

对于必须将结果返回到调用源设备的回调函数(例如 call()),会出现另一个复杂情况。在 CPU/GPU 设备上与 TPU 设备上的处理方式不同。

在 CPU/GPU 设备上,为了避免设备计算卡住等待永远不会到达的结果,如果在处理回调期间发生任何错误(无论是由用户代码本身引发还是由于返回值与预期的 return_shape 不匹配),我们都会向设备发送形状为 int8[12345] 的“伪造”结果。这将使设备计算中止,因为接收到的数据与它期望的数据不同。在 CPU 上,运行时将崩溃并显示一条独特的错误消息

` Check failed: buffer->length() == buffer_length (12345 vs. ...) `

在 GPU 上,错误更友好,并将作为以下内容显示到 Python 程序中

` RET_CHECK failure ... Mismatch between infeed source buffer shape s8[12345] ... `

要调试这些消息的根本原因,请参阅调试部分。

在 TPU 设备上,目前没有用于输入馈送的形状检查,因此我们采取更安全的做法,即在发生错误时不发送此伪造结果。这意味着计算将挂起,并且不会引发异常(但回调函数中的任何异常仍将显示在日志中)。

当前的实现使用 XLA 提供的出站馈送机制。该机制本身非常原始,因为接收器必须准确知道每个传入数据包的形状以及预期多少个数据包。这使得难以在同一计算中使用多种类型的数据,并且实际上不可能在条件语句或非恒定迭代次数的循环中使用它。此外,直接使用出站馈送机制的代码不能由 JAX 变换。主机回调函数解决了所有这些限制。此处介绍的 tap API 使得轻松地将出站馈送机制用于多种目的,同时支持所有变换。

**请注意,在使用主机回调函数后,您不能直接使用 lax.outfeed**。如果稍后需要使用 lax.outfeed,您可能需要 stop_outfeed_receiver()

由于实际对回调函数的调用是从 C++ 接收器发起的,因此调试这些调用可能很困难。特别是,堆栈跟踪将不包含调用代码。您可以使用标志jax_host_callback_inline(或环境变量JAX_HOST_CALLBACK_INLINE)来确保对回调的调用内联。这仅在调用位于分段上下文之外时有效(jit()或控制流原语)。

C++ 接收器在第一次调用id_tap()时自动启动。为了正确停止它,在启动时,会注册一个atexit 处理程序来调用barrier_wait(),并使用日志名称“at_exit”。

有一些环境变量可用于打开 C++ 输出馈送 接收器后端 的日志记录。

  • TF_CPP_MIN_LOG_LEVEL=0:将打开 INFO 日志记录,以下所有内容都需要此日志记录。

  • TF_CPP_MIN_VLOG_LEVEL=3:将使所有 VLOG 日志记录(最高 3 级)的行为类似于 INFO 日志。这可能有点多,但您将看到哪些模块正在记录相关信息,然后您可以选择要记录哪些模块的信息。

  • TF_CPP_VMODULE=<module_name>=3(模块名称可以是 C++ 或 Python,不带扩展名)。

您还应该使用--verbosity=2 标志,以便查看来自 Python 的日志。

例如,您可以尝试在host_callback 模块中启用日志记录:TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple

如果要启用较低级别实现模块中的日志记录,请尝试:TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple

(对于 Bazel 测试,请使用 –test_arg=–vmodule=…

待办事项
  • 更多性能测试。

  • 探索使用外部编译在 TPU 上实现。

  • 探索使用 XLA CustomCall 在 CPU 和 GPU 上实现。

API#

id_tap(tap_func, arg, *[, result, ...])

主机回调 tap 原语,类似于带有对tap_func 的调用的恒等函数。

id_print(arg, *[, result, tap_with_device, ...])

类似于id_tap(),带有一个打印 tap 函数。

call(callback_func, arg, *[, result_shape, ...])

对主机进行调用,并期望得到结果。

barrier_wait([logging_name])

阻塞调用线程,直到所有当前输出馈送都已处理。

CallbackException

表示某些回调函数存在异常。