编译后的打印和断点#
jax.debug
包提供了一些有用的工具,用于检查编译函数内部的值。
使用 jax.debug.print
和其他调试回调进行调试#
总结: 使用 jax.debug.print()
在编译(例如,jax.jit
或 jax.pmap
装饰的)函数中将跟踪的数组值打印到标准输出
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
jax.debug.print("🤯 {x} 🤯", x=x)
y = jnp.sin(x)
jax.debug.print("🤯 {y} 🤯", y=y)
return y
f(2.)
# Prints:
# 🤯 2.0 🤯
# 🤯 0.9092974662780762 🤯
对于某些转换,例如 jax.grad
和 jax.vmap
,你可以使用 Python 内置的 print
函数来打印数值。但是 print
不能与 jax.jit
或 jax.pmap
一起使用,因为这些转换会延迟数值的计算。因此,请改用 jax.debug.print
!
从语义上讲,jax.debug.print
大致等同于以下 Python 函数
def debug.print(fmt: str, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
print(fmt.format(*args, **kwargs))
只是它可以被 JAX 阶段化并转换。有关更多详细信息,请参阅 API 参考
。
请注意,fmt
不能是 f-string,因为 f-string 会立即格式化,而对于 jax.debug.print
,我们希望延迟格式化直到稍后。
何时使用“debug”打印?#
你应该在 JAX 转换(如 jit
,vmap
和其他转换)中使用 jax.debug.print
来打印动态(即跟踪的)数组值。对于打印静态值(如数组形状或 dtypes),你可以使用普通的 Python print
语句。
为什么使用“debug”打印?#
为了调试,jax.debug.print
可以揭示关于计算如何进行评估的方式的信息
xs = jnp.arange(3.)
def f(x):
jax.debug.print("x: {}", x)
y = jnp.sin(x)
jax.debug.print("y: {}", y)
return y
jax.vmap(f)(xs)
# Prints: x: 0.0
# x: 1.0
# x: 2.0
# y: 0.0
# y: 0.841471
# y: 0.9092974
jax.lax.map(f, xs)
# Prints: x: 0.0
# y: 0.0
# x: 1.0
# y: 0.841471
# x: 2.0
# y: 0.9092974
请注意,打印的结果顺序不同!
通过揭示这些内部工作原理,jax.debug.print
的输出不遵守 JAX 通常的语义保证,例如 jax.vmap(f)(xs)
和 jax.lax.map(f, xs)
计算相同的东西(以不同的方式)。然而,这些评估顺序细节正是我们在调试时可能想要看到的!
因此,请使用 jax.debug.print
进行调试,而不是在语义保证很重要时使用。
更多关于 jax.debug.print
的示例#
除了上面使用 jit
和 vmap
的示例外,这里还有一些需要记住的示例。
在 jax.pmap
下打印#
当使用 jax.pmap
时,jax.debug.print
可能会被重新排序!
xs = jnp.arange(2.)
def f(x):
jax.debug.print("x: {}", x)
return x
jax.pmap(f)(xs)
# Prints: x: 1.0
# x: 0.0
# OR
# Prints: x: 1.0
# x: 0.0
在 jax.grad
下打印#
在 jax.grad
下,jax.debug.print
将仅在前向传播时打印
def f(x):
jax.debug.print("x: {}", x)
return x * 2.
jax.grad(f)(1.)
# Prints: x: 1.0
此行为类似于 Python 内置的 print
在 jax.grad
下的工作方式。但是,通过在此处使用 jax.debug.print
,即使调用者应用了 jax.jit
,行为也相同。
要在反向传播时打印,只需使用 jax.custom_vjp
@jax.custom_vjp
def print_grad(x):
return x
def print_grad_fwd(x):
return x, None
def print_grad_bwd(_, x_grad):
jax.debug.print("x_grad: {}", x_grad)
return (x_grad,)
print_grad.defvjp(print_grad_fwd, print_grad_bwd)
def f(x):
x = print_grad(x)
return x * 2.
jax.grad(f)(1.)
# Prints: x_grad: 2.0
在其他转换中打印#
jax.debug.print
也适用于其他转换,如 pjit
。
使用 jax.debug.callback
进行更多控制#
事实上,jax.debug.print
是 jax.debug.callback
的一个简单的便捷包装器,可以直接使用它来更好地控制字符串格式化,甚至输出类型。
从语义上讲,jax.debug.callback
大致等同于以下 Python 函数
def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
fun(*args, **kwargs)
return None
与 jax.debug.print
一样,这些回调应仅用于调试输出,如打印或绘图。打印和绘图是相当无害的,但如果你将其用于其他任何事情,其行为可能会在转换下让你感到惊讶。例如,使用 jax.debug.callback
进行计时操作是不安全的,因为回调可能会被重新排序和异步执行(请参阅下文)。
jax.debug.print
的优点和局限性#
优点#
打印调试简单直观
jax.debug.callback
可用于其他无害的副作用
局限性#
添加打印语句是一个手动过程
可能对性能产生影响
使用 jax.debug.breakpoint()
进行交互式检查#
总结: 使用 jax.debug.breakpoint()
暂停 JAX 程序的执行以检查值
@jax.jit
def f(x):
y, z = jnp.sin(x), jnp.cos(x)
jax.debug.breakpoint()
return y * z
f(2.) # ==> Pauses during execution!
jax.debug.breakpoint()
实际上只是 jax.debug.callback(...)
的一个应用,它捕获有关调用堆栈的信息。因此,它具有与 jax.debug.print
相同的转换行为(例如,vmap
-ing jax.debug.breakpoint()
会将其在映射轴上展开)。
用法#
在编译的 JAX 函数中调用 jax.debug.breakpoint()
将在程序到达断点时暂停程序。你将看到一个类似 pdb
的提示符,允许你检查调用堆栈中的值。与 pdb
不同,你将无法逐步执行,但你可以恢复执行。
调试器命令
help
- 打印可用的命令p
- 计算表达式并打印其结果pp
- 计算表达式并漂亮地打印其结果u(p)
- 上移一个堆栈帧d(own)
- 下移一个堆栈帧w(here)/bt
- 打印回溯l(ist)
- 打印代码上下文c(ont(inue))
- 恢复程序执行q(uit)/exit
- 退出程序(在 TPU 上不起作用)
示例#
与 jax.lax.cond
的用法#
当与 jax.lax.cond
结合使用时,调试器可以成为检测 nan
或 inf
的有用工具。
def breakpoint_if_nonfinite(x):
is_finite = jnp.isfinite(x).all()
def true_fn(x):
pass
def false_fn(x):
jax.debug.breakpoint()
lax.cond(is_finite, true_fn, false_fn, x)
@jax.jit
def f(x, y):
z = x / y
breakpoint_if_nonfinite(z)
return z
f(2., 0.) # ==> Pauses during execution!
需要注意的点#
因为 jax.debug.breakpoint
只是 jax.debug.callback
的一个应用,它与 jax.debug.print
有相同的需要注意的点,但有一些额外的注意事项
jax.debug.breakpoint
比jax.debug.print
具体化更多 的中间值,因为它强制具体化调用堆栈中的所有值jax.debug.breakpoint
比jax.debug.print
具有更高的运行时开销,因为它可能需要将 JAX 程序中的所有中间值从设备复制到主机。
jax.debug.breakpoint()
的优点和局限性#
优点#
简单、直观且(某种程度上)标准
可以同时检查调用堆栈中上下多个值
局限性#
可能需要使用多个断点来精确定位错误的来源
具体化许多中间值