jax.debug.breakpoint

内容

jax.debug.breakpoint#

jax.debug.breakpoint(*, backend=None, filter_frames=True, num_frames=None, ordered=False, token=None, **kwargs)[source]#

在程序中的某个点进入断点。

参数:
  • backend (str | None | None) – 要使用的调试器后端。默认情况下,选择优先级最高的调试器,在没有其他注册的调试器的情况下,回退到 CLI 调试器。

  • filter_frames (bool) – 是否从回溯中过滤掉 JAX 内部堆栈帧。由于某些库(如 Flax)也使用 JAX 的堆栈帧过滤系统,因此此选项也会影响是否过滤掉库的堆栈帧。

  • num_frames (int | None | None) – 可用于在交互式调试器中检查的当前堆栈帧之上的帧数。

  • ordered (bool) – 一个关键字仅参数,用于指示分阶段计算是否强制执行此 jax.debug.breakpoint 相对于其他有序 jax.debug.breakpointjax.debug.print 调用的排序。

  • token – 关键字仅参数;作为 ordered 的替代方案。如果使用,则应传递 JAX 数组(或 JAX 数组的 pytree),并且断点将在其值计算完成后运行。该值将保持不变并应传回计算。如果返回值在后续计算中未被使用,则整个计算将被修剪,并且此断点将不会运行。

返回值:

如果传递了 token,则返回其值,保持不变。否则,返回 None