jax.debug.print#
- jax.debug.print(fmt, *args, ordered=False, **kwargs)[source]#
打印值并在分阶段的 JAX 函数中工作。
此函数不适用于 f-字符串,因为格式化是延迟的。因此,不要使用
jax.debug.print(f"hello {bar}")
,而应使用jax.debug.print("hello {bar}", bar=bar)
。此函数是围绕
jax.debug.callback()
的一个简单的便捷包装器。其实现本质上是def debug_print(fmt: str, *args, **kwargs): jax.debug.callback( lambda *args, **kwargs: print(fmt.format(*args, **kwargs)), *args, **kwargs)
直接调用
jax.debug.callback()
而不是此便捷包装器可能很有用。例如,要在日志中获取调试打印,您可以将jax.debug.callback()
与logging.log
一起使用。