调试简介#
本节向您介绍一组内置的 JAX 调试方法 — jax.debug.print()
、jax.debug.breakpoint()
和 jax.debug.callback()
— 您可以将它们与各种 JAX 转换一起使用。
让我们从 jax.debug.print()
开始。
jax.debug.print
用于简单检查#
这是一个经验法则
对于具有
jax.jit()
、jax.vmap()
等的跟踪(动态)数组值,请使用jax.debug.print()
。对于静态值(例如数据类型和数组形状),请使用 Python 的
print()
。
回顾 即时编译 中的内容,当使用 jax.jit()
转换函数时,Python 代码会使用抽象跟踪器来代替你的数组。因此,Python 的 print()
函数只会打印此跟踪器值。
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
print("print(x) ->", x)
y = jnp.sin(x)
print("print(y) ->", y)
return y
result = f(2.)
print(x) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
print(y) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Python 的 print
在跟踪时间执行,也就是运行时值存在之前。如果你想打印实际的运行时值,可以使用 jax.debug.print()
。
@jax.jit
def f(x):
jax.debug.print("jax.debug.print(x) -> {x}", x=x)
y = jnp.sin(x)
jax.debug.print("jax.debug.print(y) -> {y}", y=y)
return y
result = f(2.)
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314
类似地,在 jax.vmap()
中,使用 Python 的 print
只会打印跟踪器;要打印正在映射的值,请使用 jax.debug.print()
。
def f(x):
jax.debug.print("jax.debug.print(x) -> {}", x)
y = jnp.sin(x)
jax.debug.print("jax.debug.print(y) -> {}", y)
return y
xs = jnp.arange(3.)
result = jax.vmap(f)(xs)
jax.debug.print(x) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(y) -> 0.9092974066734314
以下是使用 jax.lax.map()
的结果,它是一个顺序映射而不是向量化。
result = jax.lax.map(f, xs)
jax.debug.print(x) -> 0.0
jax.debug.print(y) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314
注意顺序不同,因为 jax.vmap()
和 jax.lax.map()
以不同的方式计算相同的结果。在调试时,评估顺序细节正是你可能需要检查的内容。
下面是一个使用 jax.grad()
的示例,其中 jax.debug.print()
仅打印前向传递。在这种情况下,行为类似于 Python 的 print()
,但如果你在调用期间应用 jax.jit()
,则行为保持一致。
def f(x):
jax.debug.print("jax.debug.print(x) -> {}", x)
return x ** 2
result = jax.grad(f)(1.)
jax.debug.print(x) -> 1.0
有时,当参数彼此不依赖时,当使用 JAX 转换进行分阶段时,对 jax.debug.print()
的调用可能会以不同的顺序打印它们。如果你需要原始顺序,例如 x: ...
首先,然后 y: ...
其次,请添加 ordered=True
参数。
例如
@jax.jit
def f(x, y):
jax.debug.print("jax.debug.print(x) -> {}", x, ordered=True)
jax.debug.print("jax.debug.print(y) -> {}", y, ordered=True)
return x + y
f(1, 2)
jax.debug.print(x) -> 1
jax.debug.print(y) -> 2
Array(3, dtype=int32, weak_type=True)
要了解有关 jax.debug.print()
及其细节的更多信息,请参阅 高级调试。
jax.debug.breakpoint
用于类似 pdb
的调试#
总结:使用 jax.debug.breakpoint()
暂停 JAX 程序的执行以检查值。
为了在调试期间的某些点暂停已编译的 JAX 程序,可以使用 jax.debug.breakpoint()
。提示符类似于 Python pdb
,它允许你检查调用栈中的值。事实上,jax.debug.breakpoint()
是 jax.debug.callback()
的一个应用,它捕获有关调用栈的信息。
要在 breakpoint
调试会话期间打印所有可用命令,请使用 help
命令。(完整的调试器命令、细节、其优点和局限性在 高级调试 中介绍)。
以下是一个调试器会话可能是什么样子的示例。
@jax.jit
def f(x):
y, z = jnp.sin(x), jnp.cos(x)
jax.debug.breakpoint()
return y * z
f(2.) # ==> Pauses during execution
对于依赖于值的断点,可以使用运行时条件,例如 jax.lax.cond()
。
def breakpoint_if_nonfinite(x):
is_finite = jnp.isfinite(x).all()
def true_fn(x):
pass
def false_fn(x):
jax.debug.breakpoint()
jax.lax.cond(is_finite, true_fn, false_fn, x)
@jax.jit
def f(x, y):
z = x / y
breakpoint_if_nonfinite(z)
return z
f(2., 1.) # ==> No breakpoint
Array(2., dtype=float32, weak_type=True)
f(2., 0.) # ==> Pauses during execution
jax.debug.callback
用于在调试期间进行更多控制#
jax.debug.print()
和 jax.debug.breakpoint()
都是使用更灵活的 jax.debug.callback()
实现的,它可以通过 Python 回调提供对主机端逻辑的更多控制。它与 jax.jit()
、jax.vmap()
、jax.grad()
和其他转换兼容(有关更多信息,请参阅 外部回调 中的 回调类型 表格)。
例如
import logging
def log_value(x):
logging.warning(f'Logged value: {x}')
@jax.jit
def f(x):
jax.debug.callback(log_value, x)
return x
f(1.0);
WARNING:root:Logged value: 1.0
此回调与其他转换兼容,包括 jax.vmap()
和 jax.grad()
。
x = jnp.arange(5.0)
jax.vmap(f)(x);
WARNING:root:Logged value: 0.0
WARNING:root:Logged value: 1.0
WARNING:root:Logged value: 2.0
WARNING:root:Logged value: 3.0
WARNING:root:Logged value: 4.0
jax.grad(f)(1.0);
WARNING:root:Logged value: 1.0
这使得 jax.debug.callback()
对于通用调试非常有用。
你可以在 外部回调 中了解有关 jax.debug.callback()
和其他类型的 JAX 回调的更多信息。
后续步骤#
查看 高级调试 以了解有关 JAX 中调试的更多信息。