JAX 调试标志#
JAX 提供了标志和上下文管理器,可以更轻松地捕获错误。
jax_debug_nans
配置选项和上下文管理器#
**摘要:**启用 jax_debug_nans
标志以自动检测何时在 jax.jit
编译的代码中生成 NaN(但在 jax.pmap
或 jax.pjit
编译的代码中不会)。
jax_debug_nans
是一个 JAX 标志,启用后,会在检测到 NaN 时自动引发错误。它对 JIT 编译有特殊处理 - 当从 JIT 编译的函数中检测到 NaN 输出时,该函数将以急切方式(即不进行编译)重新运行,并在生成 NaN 的特定原语处抛出错误。
用法#
如果要跟踪函数或梯度中 NaN 的出现位置,可以通过以下方式打开 NaN 检查器:
设置
JAX_DEBUG_NANS=True
环境变量;在主文件顶部附近添加
jax.config.update("jax_debug_nans", True)
;在你的主文件中添加
jax.config.parse_flags_with_absl()
,然后使用命令行标志(例如--jax_debug_nans=True
)设置选项;
示例#
import jax
jax.config.update("jax_debug_nans", True)
def f(x, y):
return x / y
jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
jax_debug_nans
的优势和局限性#
优势#
易于应用
精确检测 NaN 的产生位置
抛出标准 Python 异常,并与 PDB 事后调试兼容
局限性#
与
jax.pmap
或jax.pjit
不兼容重新以 Eager 模式运行函数可能会很慢
错误地报告假阳性(例如,有意创建的 NaN)
jax_disable_jit
配置选项和上下文管理器#
摘要:启用 jax_disable_jit
标志以禁用 JIT 编译,从而可以使用传统的 Python 调试工具,例如 print
和 pdb
jax_disable_jit
是一个 JAX 标志,启用后,它会禁用整个 JAX 中的 JIT 编译(包括在控制流函数中,如 jax.lax.cond
和 jax.lax.scan
)。
用法#
您可以通过以下方式禁用 JIT 编译:
设置
JAX_DISABLE_JIT=True
环境变量;在主文件顶部附近添加
jax.config.update("jax_disable_jit", True)
;在你的主文件中添加
jax.config.parse_flags_with_absl()
,然后使用命令行标志(例如--jax_disable_jit=True
)设置选项;
示例#
import jax
jax.config.update("jax_disable_jit", True)
def f(x):
y = jnp.log(x)
if jnp.isnan(y):
breakpoint()
return y
jax.jit(f)(-2.) # ==> Enters PDB breakpoint!
jax_disable_jit
的优势和局限性#
优势#
易于应用
可以使用 Python 内置的
breakpoint
和print
抛出标准 Python 异常,并与 PDB 事后调试兼容
局限性#
与
jax.pmap
或jax.pjit
不兼容在没有 JIT 编译的情况下运行函数可能会很慢