jax.debug.print#

jax.debug.print(fmt, *args, ordered=False, **kwargs)[源代码]#

打印值,并可以在分阶段执行的 JAX 函数中使用。

此函数适用于 f-strings,因为格式化会被延迟。所以,不要写 jax.debug.print(f"hello {bar}"),而是写 jax.debug.print("hello {bar}", bar=bar)

此函数是 jax.debug.callback() 的一个简单的便利包装器。其实现本质上是:

def debug_print(fmt: str, *args, **kwargs):
  jax.debug.callback(
      lambda *args, **kwargs: print(fmt.format(*args, **kwargs)),
      *args, **kwargs)

直接调用 jax.debug.callback() 而不是使用此便利包装器可能很有用。例如,为了在日志中获得调试输出,你可能会将 jax.debug.callback()logging.log 一起使用。

参数:
  • fmt (str) – 一个格式化字符串,例如 "hello {x}",将用于格式化输入参数,类似于 str.format。请参阅 Python 文档中的 字符串格式化格式字符串语法

  • *args – 要格式化的位置参数列表,就像传递给 fmt.format 一样。

  • ordered (bool) – 一个仅关键字参数,用于指示分阶段执行的计算是否将强制此 jax.debug.print 与其他有序的 jax.debug.print 调用保持顺序。

  • **kwargs – 要格式化的其他关键字参数,就像传递给 fmt.format 一样。

返回类型:

None