jax.experimental.pallas.debug_print#
- jax.experimental.pallas.debug_print(fmt, *args)[源代码]#
打印 Pallas 内核内部的值。
- 参数:
fmt (str) –
一个要包含在输出中的格式字符串。对格式字符串的限制取决于后端
在 GPU 上,当使用 Triton 时,
fmt
不能包含任何占位符({...}
),因为它总是在任何值之前打印。在 GPU 上,当使用实验性的 Mosaic GPU 后端时,
fmt
必须包含每个要打印的值的占位符。不支持格式说明符和转换。所有值都必须是标量。在 TPU 上,如果所有输入都是标量:如果
fmt
包含占位符,则所有值都必须是 32 位整数。如果没有占位符,则值将在格式字符串之后打印。在 TPU 上,如果输入是单个向量,则该向量将在格式字符串之后打印。格式字符串必须以单个占位符
{}
结尾。
*args (jax.typing.ArrayLike) – 要打印的值。