jax.debug.print#
- jax.debug.print(fmt, *args, ordered=False, **kwargs)[源代码]#
打印值,并可以在分阶段执行的 JAX 函数中使用。
此函数不适用于 f-strings,因为格式化会被延迟。所以,不要写
jax.debug.print(f"hello {bar}")
,而是写jax.debug.print("hello {bar}", bar=bar)
。此函数是
jax.debug.callback()
的一个简单的便利包装器。其实现本质上是:def debug_print(fmt: str, *args, **kwargs): jax.debug.callback( lambda *args, **kwargs: print(fmt.format(*args, **kwargs)), *args, **kwargs)
直接调用
jax.debug.callback()
而不是使用此便利包装器可能很有用。例如,为了在日志中获得调试输出,你可能会将jax.debug.callback()
与logging.log
一起使用。