编译后的打印和断点#

jax.debug 包提供了一些有用的工具,用于检查编译函数内部的值。

使用 jax.debug.print 和其他调试回调进行调试#

总结: 使用 jax.debug.print() 在编译(例如,jax.jitjax.pmap 装饰的)函数中将跟踪的数组值打印到标准输出

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y

f(2.)
# Prints:
# 🤯 2.0 🤯
# 🤯 0.9092974662780762 🤯

对于某些转换,例如 jax.gradjax.vmap,你可以使用 Python 内置的 print 函数来打印数值。但是 print 不能与 jax.jitjax.pmap 一起使用,因为这些转换会延迟数值的计算。因此,请改用 jax.debug.print

从语义上讲,jax.debug.print 大致等同于以下 Python 函数

def debug.print(fmt: str, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
  print(fmt.format(*args, **kwargs))

只是它可以被 JAX 阶段化并转换。有关更多详细信息,请参阅 API 参考

请注意,fmt 不能是 f-string,因为 f-string 会立即格式化,而对于 jax.debug.print,我们希望延迟格式化直到稍后。

何时使用“debug”打印?#

你应该在 JAX 转换(如 jitvmap 和其他转换)中使用 jax.debug.print 来打印动态(即跟踪的)数组值。对于打印静态值(如数组形状或 dtypes),你可以使用普通的 Python print 语句。

为什么使用“debug”打印?#

为了调试,jax.debug.print 可以揭示关于计算如何进行评估的方式的信息

xs = jnp.arange(3.)

def f(x):
  jax.debug.print("x: {}", x)
  y = jnp.sin(x)
  jax.debug.print("y: {}", y)
  return y
jax.vmap(f)(xs)
# Prints: x: 0.0
#         x: 1.0
#         x: 2.0
#         y: 0.0
#         y: 0.841471
#         y: 0.9092974
jax.lax.map(f, xs)
# Prints: x: 0.0
#         y: 0.0
#         x: 1.0
#         y: 0.841471
#         x: 2.0
#         y: 0.9092974

请注意,打印的结果顺序不同!

通过揭示这些内部工作原理,jax.debug.print 的输出不遵守 JAX 通常的语义保证,例如 jax.vmap(f)(xs)jax.lax.map(f, xs) 计算相同的东西(以不同的方式)。然而,这些评估顺序细节正是我们在调试时可能想要看到的!

因此,请使用 jax.debug.print 进行调试,而不是在语义保证很重要时使用。

更多关于 jax.debug.print 的示例#

除了上面使用 jitvmap 的示例外,这里还有一些需要记住的示例。

jax.pmap 下打印#

当使用 jax.pmap 时,jax.debug.print 可能会被重新排序!

xs = jnp.arange(2.)

def f(x):
  jax.debug.print("x: {}", x)
  return x
jax.pmap(f)(xs)
# Prints: x: 1.0
#         x: 0.0
# OR
# Prints: x: 1.0
#         x: 0.0

jax.grad 下打印#

jax.grad 下,jax.debug.print 将仅在前向传播时打印

def f(x):
  jax.debug.print("x: {}", x)
  return x * 2.

jax.grad(f)(1.)
# Prints: x: 1.0

此行为类似于 Python 内置的 printjax.grad 下的工作方式。但是,通过在此处使用 jax.debug.print,即使调用者应用了 jax.jit,行为也相同。

要在反向传播时打印,只需使用 jax.custom_vjp

@jax.custom_vjp
def print_grad(x):
  return x

def print_grad_fwd(x):
  return x, None

def print_grad_bwd(_, x_grad):
  jax.debug.print("x_grad: {}", x_grad)
  return (x_grad,)

print_grad.defvjp(print_grad_fwd, print_grad_bwd)


def f(x):
  x = print_grad(x)
  return x * 2.
jax.grad(f)(1.)
# Prints: x_grad: 2.0

在其他转换中打印#

jax.debug.print 也适用于其他转换,如 pjit

使用 jax.debug.callback 进行更多控制#

事实上,jax.debug.printjax.debug.callback 的一个简单的便捷包装器,可以直接使用它来更好地控制字符串格式化,甚至输出类型。

从语义上讲,jax.debug.callback 大致等同于以下 Python 函数

def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
  fun(*args, **kwargs)
  return None

jax.debug.print 一样,这些回调应仅用于调试输出,如打印或绘图。打印和绘图是相当无害的,但如果你将其用于其他任何事情,其行为可能会在转换下让你感到惊讶。例如,使用 jax.debug.callback 进行计时操作是不安全的,因为回调可能会被重新排序和异步执行(请参阅下文)。

注意事项#

与大多数 JAX API 一样,如果你不小心,jax.debug.print 可能会给你带来麻烦。

打印结果的顺序#

当对 jax.debug.print 的不同调用涉及彼此不依赖的参数时,它们在被阶段化时可能会被重新排序,例如通过 jax.jit

@jax.jit
def f(x, y):
  jax.debug.print("x: {}", x)
  jax.debug.print("y: {}", y)
  return x + y

f(2., 3.)
# Prints: x: 2.0
#         y: 3.0
# OR
# Prints: y: 3.0
#         x: 2.0

为什么?在底层,编译器会获得阶段化计算的函数式表示,其中 Python 函数的命令式顺序丢失,只剩下数据依赖性。对于具有函数式纯代码的用户来说,这种变化是不可见的,但在存在诸如打印之类的副作用的情况下,它是可以注意到的。

要保留 Python 函数中编写的 jax.debug.print 的原始顺序,你可以使用 jax.debug.print(..., ordered=True),这将确保保留打印的相对顺序。但是使用 ordered=True 将在 jax.pmap 和其他涉及并行性的 JAX 转换下引发错误,因为无法保证并行执行下的顺序。

异步回调#

根据后端,jax.debug.print 可能会异步发生,即不在你的主程序线程中。这意味着即使你的 JAX 函数已返回值,值也可能会打印到你的屏幕上。

@jax.jit
def f(x):
  jax.debug.print("x: {}", x)
  return x
f(2.).block_until_ready()
# <do something else>
# Prints: x: 2.

要阻止函数中的 jax.debug.print,你可以调用 jax.effects_barrier(),它将等待直到函数中任何剩余的副作用也已完成。

@jax.jit
def f(x):
  jax.debug.print("x: {}", x)
  return x
f(2.).block_until_ready()
jax.effects_barrier()
# Prints: x: 2.
# <do something else>

性能影响#

不必要的物化#

虽然 jax.debug.print 旨在具有最小的性能开销,但它可能会干扰编译器优化并可能影响 JAX 程序的内存配置文件。

def f(w, b, x):
  logits = w.dot(x) + b
  jax.debug.print("logits: {}", logits)
  return jax.nn.relu(logits)

在此示例中,我们正在打印线性层和激活函数之间的中间值。诸如 XLA 之类的编译器可以执行融合优化,这可能会避免在内存中实例化 logits。但是当我们在 logits 上使用 jax.debug.print 时,我们正在强制实例化这些中间值,这可能会减慢程序速度并增加内存使用量。

此外,当将 jax.debug.printjax.pjit 一起使用时,会发生全局同步,这会将值实例化在单个设备上。

回调开销#

jax.debug.print 本质上会导致加速器与其主机之间的通信。底层机制因后端而异(例如,GPU 与 TPU),但在所有情况下,我们都需要将打印的值从设备复制到主机。在 CPU 情况下,此开销较小。

此外,当将 jax.debug.printjax.pjit 一起使用时,会发生全局同步,这会增加一些开销。

jax.debug.print 的优点和局限性#

优点#

  • 打印调试简单直观

  • jax.debug.callback 可用于其他无害的副作用

局限性#

  • 添加打印语句是一个手动过程

  • 可能对性能产生影响

使用 jax.debug.breakpoint() 进行交互式检查#

总结: 使用 jax.debug.breakpoint() 暂停 JAX 程序的执行以检查值

@jax.jit
def f(x):
  y, z = jnp.sin(x), jnp.cos(x)
  jax.debug.breakpoint()
  return y * z
f(2.) # ==> Pauses during execution!

JAX debugger

jax.debug.breakpoint() 实际上只是 jax.debug.callback(...) 的一个应用,它捕获有关调用堆栈的信息。因此,它具有与 jax.debug.print 相同的转换行为(例如,vmap-ing jax.debug.breakpoint() 会将其在映射轴上展开)。

用法#

在编译的 JAX 函数中调用 jax.debug.breakpoint() 将在程序到达断点时暂停程序。你将看到一个类似 pdb 的提示符,允许你检查调用堆栈中的值。与 pdb 不同,你将无法逐步执行,但你可以恢复执行。

调试器命令

  • help - 打印可用的命令

  • p - 计算表达式并打印其结果

  • pp - 计算表达式并漂亮地打印其结果

  • u(p) - 上移一个堆栈帧

  • d(own) - 下移一个堆栈帧

  • w(here)/bt - 打印回溯

  • l(ist) - 打印代码上下文

  • c(ont(inue)) - 恢复程序执行

  • q(uit)/exit - 退出程序(在 TPU 上不起作用)

示例#

jax.lax.cond 的用法#

当与 jax.lax.cond 结合使用时,调试器可以成为检测 naninf 的有用工具。

def breakpoint_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    jax.debug.breakpoint()
  lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
  z = x / y
  breakpoint_if_nonfinite(z)
  return z
f(2., 0.) # ==> Pauses during execution!

需要注意的点#

因为 jax.debug.breakpoint 只是 jax.debug.callback 的一个应用,它与 jax.debug.print 有相同的需要注意的点,但有一些额外的注意事项

  • jax.debug.breakpointjax.debug.print 具体化更多 的中间值,因为它强制具体化调用堆栈中的所有值

  • jax.debug.breakpointjax.debug.print 具有更高的运行时开销,因为它可能需要将 JAX 程序中的所有中间值从设备复制到主机。

jax.debug.breakpoint() 的优点和局限性#

优点#

  • 简单、直观且(某种程度上)标准

  • 可以同时检查调用堆栈中上下多个值

局限性#

  • 可能需要使用多个断点来精确定位错误的来源

  • 具体化许多中间值