编译后的打印和断点#
jax.debug
包提供了一些用于检查编译函数内部值的实用工具。
使用 jax.debug.print
和其他调试回调进行调试#
总结:使用 jax.debug.print()
将追踪到的数组值打印到已编译(例如,用 jax.jit
或 jax.pmap
装饰的)函数的标准输出中
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
jax.debug.print("🤯 {x} 🤯", x=x)
y = jnp.sin(x)
jax.debug.print("🤯 {y} 🤯", y=y)
return y
f(2.)
# Prints:
# 🤯 2.0 🤯
# 🤯 0.9092974662780762 🤯
通过一些转换,例如 jax.grad
和 jax.vmap
,你可以使用 Python 的内置函数 print
来打印数值。但 print
无法与 jax.jit
或 jax.pmap
一起使用,因为这些转换会延迟数值计算。所以请使用 jax.debug.print
代替!
从语义上讲,jax.debug.print
大致等同于以下 Python 函数
def debug.print(fmt: str, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
print(fmt.format(*args, **kwargs))
区别在于它可以被 JAX 阶段化和转换。有关更多详细信息,请参阅 API 参考
。
请注意,fmt
不能是 f-string,因为 f-string 会立即格式化,而对于 jax.debug.print
,我们希望延迟格式化直到稍后。
何时使用“debug”打印?#
你应该在 JAX 转换(如 jit
、vmap
等)中使用 jax.debug.print
来打印动态(即跟踪)数组值。对于静态值的打印(如数组形状或数据类型),你可以使用正常的 Python print
语句。
为什么是“debug”打印?#
为了调试,jax.debug.print
可以揭示有关计算如何执行的信息
xs = jnp.arange(3.)
def f(x):
jax.debug.print("x: {}", x)
y = jnp.sin(x)
jax.debug.print("y: {}", y)
return y
jax.vmap(f)(xs)
# Prints: x: 0.0
# x: 1.0
# x: 2.0
# y: 0.0
# y: 0.841471
# y: 0.9092974
jax.lax.map(f, xs)
# Prints: x: 0.0
# y: 0.0
# x: 1.0
# y: 0.841471
# x: 2.0
# y: 0.9092974
请注意,打印的结果顺序不同!
通过揭示这些内部工作机制,jax.debug.print
的输出不遵守 JAX 的通常语义保证,例如 jax.vmap(f)(xs)
和 jax.lax.map(f, xs)
以不同的方式计算相同的结果。然而,这些计算顺序细节正是我们调试时可能想要查看的内容!
因此,在调试时使用 jax.debug.print
,而不要在语义保证很重要的场合使用它。
更多 jax.debug.print
的示例#
除了上面使用 jit
和 vmap
的示例之外,这里还有一些需要牢记的示例。
在 jax.pmap
下打印#
当使用 jax.pmap
时,jax.debug.print
可能被重新排序!
xs = jnp.arange(2.)
def f(x):
jax.debug.print("x: {}", x)
return x
jax.pmap(f)(xs)
# Prints: x: 1.0
# x: 0.0
# OR
# Prints: x: 1.0
# x: 0.0
在 jax.grad
下打印#
在 jax.grad
下,jax.debug.print
仅在正向传播时打印
def f(x):
jax.debug.print("x: {}", x)
return x * 2.
jax.grad(f)(1.)
# Prints: x: 1.0
这种行为类似于 Python 的内置函数 print
在 jax.grad
下的工作方式。但是通过在这里使用 jax.debug.print
,即使调用者应用了 jax.jit
,行为也保持一致。
要在反向传播时打印,只需使用 jax.custom_vjp
@jax.custom_vjp
def print_grad(x):
return x
def print_grad_fwd(x):
return x, None
def print_grad_bwd(_, x_grad):
jax.debug.print("x_grad: {}", x_grad)
return (x_grad,)
print_grad.defvjp(print_grad_fwd, print_grad_bwd)
def f(x):
x = print_grad(x)
return x * 2.
jax.grad(f)(1.)
# Prints: x_grad: 2.0
在其他转换中打印#
jax.debug.print
也适用于其他转换,如 pjit
。
使用 jax.debug.callback
进行更多控制#
实际上,jax.debug.print
是对 jax.debug.callback
的一个简单方便的包装器,可以直接使用它来更好地控制字符串格式化,甚至输出类型。
从语义上讲,jax.debug.callback
大致等同于以下 Python 函数
def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
fun(*args, **kwargs)
return None
与 jax.debug.print
一样,这些回调仅应用于调试输出,如打印或绘图。打印和绘图非常安全,但如果你将它用于其他任何用途,它在转换下的行为可能会让你感到意外。例如,使用 jax.debug.callback
进行计时操作是不安全的,因为回调可能会被重新排序并且是异步的(见下文)。
jax.debug.print
的优点和局限性#
优点#
打印调试简单直观
jax.debug.callback
可用于其他无害的副作用
局限性#
添加打印语句是一个手动过程
可能会对性能产生影响
使用 jax.debug.breakpoint()
进行交互式检查#
总结:使用 jax.debug.breakpoint()
暂停 JAX 程序的执行,以检查值
@jax.jit
def f(x):
y, z = jnp.sin(x), jnp.cos(x)
jax.debug.breakpoint()
return y * z
f(2.) # ==> Pauses during execution!
jax.debug.breakpoint()
实际上只是对 jax.debug.callback(...)
的应用,它捕获有关调用堆栈的信息。因此,它具有与 jax.debug.print
相同的转换行为(例如,对 jax.debug.breakpoint()
进行 vmap
会在映射轴上展开它)。
用法#
在已编译的 JAX 函数中调用 jax.debug.breakpoint()
将在程序遇到断点时暂停程序。你将看到一个类似于 pdb
的提示,允许你检查调用堆栈中的值。与 pdb
不同,你将无法逐步执行,但可以恢复执行。
调试器命令
help
- 打印可用的命令p
- 评估表达式并打印结果pp
- 评估表达式并以格式化方式打印结果u(p)
- 上移一个栈帧d(own)
- 下移一个栈帧w(here)/bt
- 打印回溯信息l(ist)
- 打印代码上下文c(ont(inue))
- 恢复程序执行q(uit)/exit
- 退出程序(在 TPU 上不可用)
示例#
与 jax.lax.cond
的用法#
当与 jax.lax.cond
结合使用时,调试器可以成为检测 nan
或 inf
的有用工具。
def breakpoint_if_nonfinite(x):
is_finite = jnp.isfinite(x).all()
def true_fn(x):
pass
def false_fn(x):
jax.debug.breakpoint()
lax.cond(is_finite, true_fn, false_fn, x)
@jax.jit
def f(x, y):
z = x / y
breakpoint_if_nonfinite(z)
return z
f(2., 0.) # ==> Pauses during execution!
关键点#
由于 jax.debug.breakpoint
仅仅是 jax.debug.callback
的一个应用,它与 jax.debug.print
的关键点 一样,但也有一些额外的注意事项
jax.debug.breakpoint
比jax.debug.print
生成更多中间结果,因为它强制将调用栈中的所有值都物化。jax.debug.breakpoint
比jax.debug.print
具有更高的运行时开销,因为它可能需要将 JAX 程序中的所有中间值从设备复制到主机。
jax.debug.breakpoint()
的优缺点#
优点#
简单、直观且(某种程度上)标准
可以同时检查调用栈上下文的多个值
局限性#
可能需要使用多个断点来精确定位错误源
会生成很多中间结果