外部回调#

本教程概述了如何使用各种回调函数,这些函数允许 JAX 运行时在主机上执行 Python 代码。JAX 回调的示例包括 jax.pure_callback()jax.experimental.io_callback()jax.debug.callback()。即使在 JAX 转换下运行时,包括 jit()vmap()grad(),您也可以使用它们。

为什么需要回调?#

回调例程是一种在运行时执行代码的主机端方法。举个简单的例子,假设您想在计算过程中打印某个变量的。使用简单的 Python print() 语句,它看起来像这样:

import jax

@jax.jit
def f(x):
  y = x + 1
  print("intermediate value: {}".format(y))
  return y * 2

result = f(2)
intermediate value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>

打印的不是运行时值,而是跟踪时的抽象值(如果您不熟悉 JAX 中的跟踪,可以在 跟踪 中找到一个很好的入门知识)。

要在运行时打印值,您需要一个回调,例如 jax.debug.print()(您可以在 调试简介 中了解有关调试的更多信息)

@jax.jit
def f(x):
  y = x + 1
  jax.debug.print("intermediate value: {}", y)
  return y * 2

result = f(2)
intermediate value: 3

它的工作原理是将 y 的运行时值作为 CPU jax.Array 传递回主机进程,主机可以在其中打印它。

回调的类型#

在早期版本的 JAX 中,只有一种可用的回调,在 jax.experimental.host_callback() 中实现。host_callback 例程有一些缺陷,现在已被弃用,转而使用为不同情况设计的几种回调。

(您之前使用的 jax.debug.print() 函数是 jax.debug.callback() 的包装器)。

从用户的角度来看,这三种回调类型的主要区别在于它们允许的转换和编译器优化。

回调函数

支持返回值

jit

vmap

grad

scan/while_loop

保证执行

jax.pure_callback()

❌¹

jax.experimental.io_callback()

✅/❌²

✅³

jax.debug.callback()

¹ jax.pure_callback 可以与 custom_jvp 一起使用,使其与自动微分兼容

² jax.experimental.io_callback 仅在 ordered=False 时才与 vmap 兼容。

³ 请注意,io_callbackscan/while_loopvmap 具有复杂的语义,其行为可能会在未来版本中发生变化。

探索 pure_callback#

jax.pure_callback() 通常是当您想要在主机端执行纯函数时应该选择的回调函数:即没有副作用(例如打印值、从磁盘读取数据、更新全局状态等)的函数。

您传递给 jax.pure_callback() 的函数实际上不必是纯函数,但 JAX 的转换和高阶函数会将其视为纯函数,这意味着它可能会被静默地省略或多次调用。

import jax
import jax.numpy as jnp
import numpy as np

def f_host(x):
  # call a numpy (not jax.numpy) operation:
  return np.sin(x).astype(x.dtype)

def f(x):
  result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.pure_callback(f_host, result_shape, x)

x = jnp.arange(5.0)
f(x)
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)

由于 pure_callback 可以被省略或重复,因此它与 jitvmap 等转换以及 scanwhile_loop 等高阶原语开箱即用兼容:”

jax.jit(f)(x)
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)
jax.vmap(f)(x)
/tmp/ipykernel_889/3691550925.py:11: DeprecationWarning: The default behavior of pure_callback under vmap will soon change. Currently, the default behavior is to generate a sequential vmap (i.e. a loop), but in the future the default will be to raise an error. To keep the current default, set vmap_method='sequential'.
  return jax.pure_callback(f_host, result_shape, x)
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)
def body_fun(_, x):
  return _, f(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)

然而,由于 JAX 无法内省回调的内容,pure_callback 具有未定义的自动微分语义

jax.grad(f)(x)
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.

有关将 pure_callbackjax.custom_jvp() 一起使用的示例,请参阅下面的示例:带有 custom_jvppure_callback

根据设计,传递给 pure_callback 的函数被视为没有副作用:这样做的结果之一是,如果函数输出未被使用,编译器可能会完全消除回调

def print_something():
  print('printing something')
  return np.int32(0)

@jax.jit
def f1():
  return jax.pure_callback(print_something, np.int32(0))
f1();
printing something
@jax.jit
def f2():
  jax.pure_callback(print_something, np.int32(0))
  return 1.0
f2();

f1 中,回调的输出在函数的返回值中使用,因此会执行回调,并且我们看到打印的输出。另一方面,在 f2 中,回调的输出未使用,因此编译器会注意到这一点并消除函数调用。这些是没有副作用的函数的回调的正确语义。

探索 io_callback#

jax.pure_callback() 不同,jax.experimental.io_callback() 明确用于非纯函数,即具有副作用的函数。

例如,这是一个回调到全局主机端 numpy 随机生成器的回调。这是一个不纯的操作,因为在 numpy 中生成随机数的副作用是随机状态会更新(请注意,这只是 io_callback 的一个玩具示例,而不是在 JAX 中生成随机数的推荐方法!)。

from jax.experimental import io_callback
from functools import partial

global_rng = np.random.default_rng(0)

def host_side_random_like(x):
  """Generate a random array like x using the global_rng state"""
  # We have two side-effects here:
  # - printing the shape and dtype
  # - calling global_rng, thus updating its state
  print(f'generating {x.dtype}{list(x.shape)}')
  return global_rng.uniform(size=x.shape).astype(x.dtype)

@jax.jit
def numpy_random_like(x):
  return io_callback(host_side_random_like, x, x)

x = jnp.zeros(5)
numpy_random_like(x)
generating float32[5]
Array([0.6369617 , 0.26978672, 0.04097353, 0.01652764, 0.8132702 ],      dtype=float32)

io_callback 默认情况下与 vmap 兼容

jax.vmap(numpy_random_like)(x)
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
Array([0.91275555, 0.60663575, 0.72949654, 0.543625  , 0.9350724 ],      dtype=float32)

但是请注意,这可能会以任意顺序执行映射的回调。因此,例如,如果您在 GPU 上运行此操作,则映射输出的顺序可能因运行而异。

如果回调的顺序必须保留,您可以设置 ordered=True,在这种情况下,尝试 vmap 将会引发错误

@jax.jit
def numpy_random_like_ordered(x):
  return io_callback(host_side_random_like, x, x, ordered=True)

jax.vmap(numpy_random_like_ordered)(x)
ValueError: Cannot `vmap` ordered IO callback.

另一方面,无论是否强制排序,scanwhile_loop 都与 io_callback 一起使用

def body_fun(_, x):
  return _, numpy_random_like_ordered(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
Array([0.81585354, 0.0027385 , 0.8574043 , 0.03358557, 0.72965544],      dtype=float32)

pure_callback 类似,如果向 io_callback 传递微分变量,则自动微分会失败

jax.grad(numpy_random_like)(x)
ValueError: IO callbacks do not support JVP.

但是,如果回调不依赖于微分变量,它将执行

@jax.jit
def f(x):
  io_callback(lambda: print('hello'), None)
  return x

jax.grad(f)(1.0);
hello

pure_callback 不同,在这种情况下,编译器不会删除回调执行,即使回调的输出在后续计算中未使用。

探索 debug.callback#

pure_callbackio_callback 都对它们调用的函数的纯度强制执行一些假设,并以各种方式限制 JAX 转换和编译机制可能执行的操作。 debug.callback 本质上对回调函数不作任何假设,因此回调的操作完全反映了 JAX 在程序执行过程中的操作。此外,debug.callback 不能向程序返回任何值。

from jax import debug

def log_value(x):
  # This could be an actual logging call; we'll use
  # print() for demonstration
  print("log:", x)

@jax.jit
def f(x):
  debug.callback(log_value, x)
  return x

f(1.0);
log: 1.0

debug 回调与 vmap 兼容

x = jnp.arange(5.0)
jax.vmap(f)(x);
log: 0.0
log: 1.0
log: 2.0
log: 3.0
log: 4.0

并且也与 grad 和其他自动微分转换兼容

jax.grad(f)(1.0);
log: 1.0

这使得 debug.callbackpure_callbackio_callback 更适合用于通用调试。

示例:带有 custom_jvppure_callback#

结合 jax.pure_callback()jax.custom_jvp 是一种利用 jax.pure_callback() 的强大方法。(有关 jax.custom_jvp() 的更多详细信息,请参阅 JAX可转换Python函数的自定义导数规则)。

假设您想为 scipy 或 numpy 中尚未在 jax.scipyjax.numpy 封装器中提供的函数创建一个 JAX 兼容的封装器。

在这里,我们将考虑为 scipy.special.jv 中提供的第一类贝塞尔函数创建封装器。您可以从定义一个简单的 pure_callback() 开始

import jax
import jax.numpy as jnp
import scipy.special

def jv(v, z):
  v, z = jnp.asarray(v), jnp.asarray(z)

  # Require the order v to be integer type: this simplifies
  # the JVP rule below.
  assert jnp.issubdtype(v.dtype, jnp.integer)

  # Promote the input to inexact (float/complex).
  # Note that jnp.result_type() accounts for the enable_x64 flag.
  z = z.astype(jnp.result_type(float, z.dtype))

  # Wrap scipy function to return the expected dtype.
  _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)

  # Define the expected shape & dtype of output.
  result_shape_dtype = jax.ShapeDtypeStruct(
      shape=jnp.broadcast_shapes(v.shape, z.shape),
      dtype=z.dtype)

  # You use vectorize=True because scipy.special.jv handles broadcasted inputs.
  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)

这使我们可以从转换后的 JAX 代码中调用 scipy.special.jv(),包括被 jit()vmap() 转换时。

from functools import partial
j1 = partial(jv, 1)
z = jnp.arange(5.0)
print(j1(z))
[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]
/tmp/ipykernel_889/2939642295.py:25: DeprecationWarning: The vectorized argument of jax.pure_callback is deprecated and setting it will soon raise an error. To avoid an error in the future, and to suppress this warning, please use the vmap_method argument instead.
  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)

这是使用 jit() 的相同结果

print(jax.jit(j1)(z))
[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]
/tmp/ipykernel_889/2939642295.py:25: DeprecationWarning: The vectorized argument of jax.pure_callback is deprecated and setting it will soon raise an error. To avoid an error in the future, and to suppress this warning, please use the vmap_method argument instead.
  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)

这是再次使用 vmap() 的相同结果

print(jax.vmap(j1)(z))
[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]
/tmp/ipykernel_889/2939642295.py:25: DeprecationWarning: The vectorized argument of jax.pure_callback is deprecated and setting it will soon raise an error. To avoid an error in the future, and to suppress this warning, please use the vmap_method argument instead.
  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)

但是,如果您调用 grad(),您将收到错误,因为没有为此函数定义自动微分规则

jax.grad(j1)(z)
/tmp/ipykernel_889/2939642295.py:25: DeprecationWarning: The vectorized argument of jax.pure_callback is deprecated and setting it will soon raise an error. To avoid an error in the future, and to suppress this warning, please use the vmap_method argument instead.
  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.

让我们为此定义一个自定义梯度规则。查看第一类贝塞尔函数的定义,您会发现关于参数 z 的导数有一个相对简单的递归关系

\[\begin{split} d J_\nu(z) = \left\{ \begin{eqnarray} -J_1(z),\ &\nu=0\\ [J_{\nu - 1}(z) - J_{\nu + 1}(z)]/2,\ &\nu\ne 0 \end{eqnarray}\right. \end{split}\]

关于 \(\nu\) 的梯度更复杂,但由于我们将 v 参数限制为整数类型,因此为了本示例,您无需担心它的梯度。

您可以使用 jax.custom_jvp() 为您的回调函数定义此自动微分规则

jv = jax.custom_jvp(jv)

@jv.defjvp
def _jv_jvp(primals, tangents):
  v, z = primals
  _, z_dot = tangents  # Note: v_dot is always 0 because v is integer.
  jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)
  djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))
  return jv(v, z), z_dot * djv_dz

现在,计算函数的梯度将正常工作

j1 = partial(jv, 1)
print(jax.grad(j1)(2.0))
-0.06447162
/tmp/ipykernel_889/2939642295.py:25: DeprecationWarning: The vectorized argument of jax.pure_callback is deprecated and setting it will soon raise an error. To avoid an error in the future, and to suppress this warning, please use the vmap_method argument instead.
  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)

此外,由于我们已根据 jv 本身定义了您的梯度,因此 JAX 的架构意味着您可以免费获得二阶和更高阶的导数

jax.hessian(j1)(2.0)
/tmp/ipykernel_889/2939642295.py:25: DeprecationWarning: The vectorized argument of jax.pure_callback is deprecated and setting it will soon raise an error. To avoid an error in the future, and to suppress this warning, please use the vmap_method argument instead.
  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)
Array(-0.4003078, dtype=float32, weak_type=True)

请记住,尽管这一切都可以与 JAX 正确配合使用,但是每次调用基于回调的 jv 函数都会导致将输入数据从设备传递到主机,并将 scipy.special.jv() 的输出从主机传递回设备。

在 GPU 或 TPU 等加速器上运行时,每次调用 jv 时,这种数据移动和主机同步可能会导致明显的开销。

但是,如果在单个 CPU 上运行 JAX(其中“主机”和“设备”位于同一硬件上),则 JAX 通常会以快速、零复制的方式进行此数据传输,从而使此模式成为扩展 JAX 功能的相对简单的方法。