jax.default_matmul_precision#
- jax.default_matmul_precision = <jax._src.config.State object>#
用于 jax_default_matmul_precision 配置选项的上下文管理器。
控制 32 位输入的默认矩阵乘法和卷积精度。
某些平台(例如 TPU)为矩阵乘法和卷积计算提供可配置的精度级别,在速度和精度之间进行权衡。可以为每个操作控制精度;例如,请参阅
jax.lax.conv_general_dilated()
和jax.lax.dot()
文档字符串。但是,控制在操作未指定特定精度时获得的默认行为可能很有用。此选项可用于控制参与 32 位输入的矩阵乘法和卷积计算的默认精度级别。这些级别大致描述了计算标量积的精度。‘bfloat16’ 选项最快但精度最低;‘float32’ 与完整的 float32 精度类似;‘tensorfloat32’ 是介于两者之间。
- 参数:
new_val (Any)