编译后的打印和断点#

jax.debug 包提供了一些用于检查编译函数内部值的实用工具。

使用 jax.debug.print 和其他调试回调进行调试#

总结:使用 jax.debug.print() 将追踪到的数组值打印到已编译(例如,用 jax.jitjax.pmap 装饰的)函数的标准输出中

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y

f(2.)
# Prints:
# 🤯 2.0 🤯
# 🤯 0.9092974662780762 🤯

通过一些转换,例如 jax.gradjax.vmap,你可以使用 Python 的内置函数 print 来打印数值。但 print 无法与 jax.jitjax.pmap 一起使用,因为这些转换会延迟数值计算。所以请使用 jax.debug.print 代替!

从语义上讲,jax.debug.print 大致等同于以下 Python 函数

def debug.print(fmt: str, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
  print(fmt.format(*args, **kwargs))

区别在于它可以被 JAX 阶段化和转换。有关更多详细信息,请参阅 API 参考

请注意,fmt 不能是 f-string,因为 f-string 会立即格式化,而对于 jax.debug.print,我们希望延迟格式化直到稍后。

何时使用“debug”打印?#

你应该在 JAX 转换(如 jitvmap 等)中使用 jax.debug.print 来打印动态(即跟踪)数组值。对于静态值的打印(如数组形状或数据类型),你可以使用正常的 Python print 语句。

为什么是“debug”打印?#

为了调试,jax.debug.print 可以揭示有关计算如何执行的信息

xs = jnp.arange(3.)

def f(x):
  jax.debug.print("x: {}", x)
  y = jnp.sin(x)
  jax.debug.print("y: {}", y)
  return y
jax.vmap(f)(xs)
# Prints: x: 0.0
#         x: 1.0
#         x: 2.0
#         y: 0.0
#         y: 0.841471
#         y: 0.9092974
jax.lax.map(f, xs)
# Prints: x: 0.0
#         y: 0.0
#         x: 1.0
#         y: 0.841471
#         x: 2.0
#         y: 0.9092974

请注意,打印的结果顺序不同!

通过揭示这些内部工作机制,jax.debug.print 的输出不遵守 JAX 的通常语义保证,例如 jax.vmap(f)(xs)jax.lax.map(f, xs) 以不同的方式计算相同的结果。然而,这些计算顺序细节正是我们调试时可能想要查看的内容!

因此,在调试时使用 jax.debug.print,而不要在语义保证很重要的场合使用它。

更多 jax.debug.print 的示例#

除了上面使用 jitvmap 的示例之外,这里还有一些需要牢记的示例。

jax.pmap 下打印#

当使用 jax.pmap 时,jax.debug.print 可能被重新排序!

xs = jnp.arange(2.)

def f(x):
  jax.debug.print("x: {}", x)
  return x
jax.pmap(f)(xs)
# Prints: x: 1.0
#         x: 0.0
# OR
# Prints: x: 1.0
#         x: 0.0

jax.grad 下打印#

jax.grad 下,jax.debug.print 仅在正向传播时打印

def f(x):
  jax.debug.print("x: {}", x)
  return x * 2.

jax.grad(f)(1.)
# Prints: x: 1.0

这种行为类似于 Python 的内置函数 printjax.grad 下的工作方式。但是通过在这里使用 jax.debug.print,即使调用者应用了 jax.jit,行为也保持一致。

要在反向传播时打印,只需使用 jax.custom_vjp

@jax.custom_vjp
def print_grad(x):
  return x

def print_grad_fwd(x):
  return x, None

def print_grad_bwd(_, x_grad):
  jax.debug.print("x_grad: {}", x_grad)
  return (x_grad,)

print_grad.defvjp(print_grad_fwd, print_grad_bwd)


def f(x):
  x = print_grad(x)
  return x * 2.
jax.grad(f)(1.)
# Prints: x_grad: 2.0

在其他转换中打印#

jax.debug.print 也适用于其他转换,如 pjit

使用 jax.debug.callback 进行更多控制#

实际上,jax.debug.print 是对 jax.debug.callback 的一个简单方便的包装器,可以直接使用它来更好地控制字符串格式化,甚至输出类型。

从语义上讲,jax.debug.callback 大致等同于以下 Python 函数

def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
  fun(*args, **kwargs)
  return None

jax.debug.print 一样,这些回调仅应用于调试输出,如打印或绘图。打印和绘图非常安全,但如果你将它用于其他任何用途,它在转换下的行为可能会让你感到意外。例如,使用 jax.debug.callback 进行计时操作是不安全的,因为回调可能会被重新排序并且是异步的(见下文)。

关键点#

与大多数 JAX API 一样,如果你不小心,jax.debug.print 会给你带来麻烦。

打印结果的顺序#

当对 jax.debug.print 的不同调用涉及彼此不依赖的参数时,它们在被阶段化(例如,通过 jax.jit)时可能会被重新排序

@jax.jit
def f(x, y):
  jax.debug.print("x: {}", x)
  jax.debug.print("y: {}", y)
  return x + y

f(2., 3.)
# Prints: x: 2.0
#         y: 3.0
# OR
# Prints: y: 3.0
#         x: 2.0

为什么?在幕后,编译器获取阶段化计算的函数表示,其中 Python 函数的命令式顺序丢失,只有数据依赖关系保留下来。对于功能纯净的代码,这种改变对用户来说是不可见的,但在存在打印等副作用的情况下,它就变得明显了。

要保留 Python 函数中编写的 jax.debug.print 的原始顺序,可以使用 jax.debug.print(..., ordered=True),这将确保打印的相对顺序得到保留。但是,使用 ordered=True 会在 jax.pmap 和其他涉及并行的 JAX 转换下引发错误,因为在并行执行下无法保证顺序。

异步回调#

根据后端的不同,jax.debug.print 可能会异步发生,即不会在你的主程序线程中发生。这意味着即使你的 JAX 函数已经返回了值,也可能将值打印到你的屏幕上。

@jax.jit
def f(x):
  jax.debug.print("x: {}", x)
  return x
f(2.).block_until_ready()
# <do something else>
# Prints: x: 2.

要阻塞函数中的 jax.debug.print,可以调用 jax.effects_barrier(),它将等待函数中任何剩余的副作用完成

@jax.jit
def f(x):
  jax.debug.print("x: {}", x)
  return x
f(2.).block_until_ready()
jax.effects_barrier()
# Prints: x: 2.
# <do something else>

性能影响#

不必要的物化#

虽然 jax.debug.print 被设计为具有最小的性能影响,但它会干扰编译器优化,并可能影响你的 JAX 程序的内存配置文件。

def f(w, b, x):
  logits = w.dot(x) + b
  jax.debug.print("logits: {}", logits)
  return jax.nn.relu(logits)

在这个例子中,我们在线性层和激活函数之间打印中间值。像 XLA 这样的编译器可以执行融合优化,这可能会避免在内存中物化 logits。但是,当我们在 logits 上使用 jax.debug.print 时,我们正在强制这些中间值物化,这可能会减慢程序速度并增加内存使用量。

此外,当将 jax.debug.printjax.pjit 一起使用时,会发生全局同步,这将使值在单个设备上物化。

回调开销#

jax.debug.print 本身会带来加速器与其主机之间的通信开销。底层机制因后端而异(例如,GPU 与 TPU),但在所有情况下,都需要将打印的值从设备复制到主机。在 CPU 情况下,此开销较小。

此外,当将 jax.debug.printjax.pjit 一起使用时,会发生全局同步,这会增加一些开销。

jax.debug.print 的优点和局限性#

优点#

  • 打印调试简单直观

  • jax.debug.callback 可用于其他无害的副作用

局限性#

  • 添加打印语句是一个手动过程

  • 可能会对性能产生影响

使用 jax.debug.breakpoint() 进行交互式检查#

总结:使用 jax.debug.breakpoint() 暂停 JAX 程序的执行,以检查值

@jax.jit
def f(x):
  y, z = jnp.sin(x), jnp.cos(x)
  jax.debug.breakpoint()
  return y * z
f(2.) # ==> Pauses during execution!

JAX debugger

jax.debug.breakpoint() 实际上只是对 jax.debug.callback(...) 的应用,它捕获有关调用堆栈的信息。因此,它具有与 jax.debug.print 相同的转换行为(例如,对 jax.debug.breakpoint() 进行 vmap 会在映射轴上展开它)。

用法#

在已编译的 JAX 函数中调用 jax.debug.breakpoint() 将在程序遇到断点时暂停程序。你将看到一个类似于 pdb 的提示,允许你检查调用堆栈中的值。与 pdb 不同,你将无法逐步执行,但可以恢复执行。

调试器命令

  • help - 打印可用的命令

  • p - 评估表达式并打印结果

  • pp - 评估表达式并以格式化方式打印结果

  • u(p) - 上移一个栈帧

  • d(own) - 下移一个栈帧

  • w(here)/bt - 打印回溯信息

  • l(ist) - 打印代码上下文

  • c(ont(inue)) - 恢复程序执行

  • q(uit)/exit - 退出程序(在 TPU 上不可用)

示例#

jax.lax.cond 的用法#

当与 jax.lax.cond 结合使用时,调试器可以成为检测 naninf 的有用工具。

def breakpoint_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    jax.debug.breakpoint()
  lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
  z = x / y
  breakpoint_if_nonfinite(z)
  return z
f(2., 0.) # ==> Pauses during execution!

关键点#

由于 jax.debug.breakpoint 仅仅是 jax.debug.callback 的一个应用,它与 jax.debug.print 的关键点 一样,但也有一些额外的注意事项

  • jax.debug.breakpointjax.debug.print 生成更多中间结果,因为它强制将调用栈中的所有值都物化。

  • jax.debug.breakpointjax.debug.print 具有更高的运行时开销,因为它可能需要将 JAX 程序中的所有中间值从设备复制到主机。

jax.debug.breakpoint() 的优缺点#

优点#

  • 简单、直观且(某种程度上)标准

  • 可以同时检查调用栈上下文的多个值

局限性#

  • 可能需要使用多个断点来精确定位错误源

  • 会生成很多中间结果