类型提升语义#
本文档描述了 JAX 的类型提升规则,即每对类型的 jax.numpy.promote_types()
的结果。关于下方所述设计考虑的一些背景信息,请参阅 JAX 类型提升语义的设计。
JAX 的类型提升行为通过以下类型提升格确定
例如:
b1
表示np.bool_
,i2
表示np.int16
,u4
表示np.uint32
,bf
表示np.bfloat16
,f2
表示np.float16
,c8
表示np.complex64
,i*
表示 Pythonint
或弱类型int
,f*
表示 Pythonfloat
或弱类型float
,以及c*
代表 Pythoncomplex
或弱类型complex
。
(有关弱类型的更多信息,请参见下面 JAX 中的弱类型值)。
任何两种类型之间的提升由它们在此格子上进行的 连接 给出,这会生成以下二元提升表
b1 | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i* | f* | c* | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
b1 | b1 | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i* | f* | c* |
u1 | u1 | u1 | u2 | u4 | u8 | i2 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | u1 | f* | c* |
u2 | u2 | u2 | u2 | u4 | u8 | i4 | i4 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | u2 | f* | c* |
u4 | u4 | u4 | u4 | u4 | u8 | i8 | i8 | i8 | i8 | bf | f2 | f4 | f8 | c8 | c16 | u4 | f* | c* |
u8 | u8 | u8 | u8 | u8 | u8 | f* | f* | f* | f* | bf | f2 | f4 | f8 | c8 | c16 | u8 | f* | c* |
i1 | i1 | i2 | i4 | i8 | f* | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i1 | f* | c* |
i2 | i2 | i2 | i4 | i8 | f* | i2 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i2 | f* | c* |
i4 | i4 | i4 | i4 | i8 | f* | i4 | i4 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i4 | f* | c* |
i8 | i8 | i8 | i8 | i8 | f* | i8 | i8 | i8 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i8 | f* | c* |
bf | bf | bf | bf | bf | bf | bf | bf | bf | bf | bf | f4 | f4 | f8 | c8 | c16 | bf | bf | c8 |
f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f4 | f2 | f4 | f8 | c8 | c16 | f2 | f2 | c8 |
f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f8 | c8 | c16 | f4 | f4 | c8 |
f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | c16 | c16 | f8 | f8 | c16 |
c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c16 | c8 | c16 | c8 | c8 | c8 |
c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 |
i* | i* | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i* | f* | c* |
f* | f* | f* | f* | f* | f* | f* | f* | f* | f* | bf | f2 | f4 | f8 | c8 | c16 | f* | f* | c* |
c* | c* | c* | c* | c* | c* | c* | c* | c* | c* | c8 | c8 | c8 | c16 | c8 | c16 | c* | c* | c* |
JAX 的类型提升规则与 NumPy 的不同,如 numpy.promote_types()
所示,上述表格中用绿色背景突出显示的单元格即为区别所在。主要有三种类型的区别
在将弱类型值与同一类别的 JAX 类型值提升时,JAX 始终优先选择 JAX 值的精度。例如,
jnp.int16(1) + 1
将返回int16
而不是像 NumPy 中那样提升为int64
。请注意,这仅适用于 Python 标量值;如果常量是 NumPy 数组,则使用上述格点进行类型提升。例如,jnp.int16(1) + np.array(1)
将返回int64
。在将整数或布尔类型与浮点类型或复数类型提升时,JAX 始终优先选择浮点类型或复数类型的类型。
JAX 支持 bfloat16 非标准 16 位浮点类型 (
jax.numpy.bfloat16
),这在神经网络训练中很有用。唯一值得注意的提升行为是相对于 IEEE-754float16
,bfloat16
会提升为float32
。
NumPy 和 JAX 之间的差异源于这样一个事实,即加速器设备(例如 GPU 和 TPU)要么在使用 64 位浮点类型时会产生很大的性能损失(GPU),要么根本不支持 64 位浮点类型(TPU)。经典的 NumPy 的提升规则过于倾向于过度提升为 64 位类型,这对于旨在在加速器上运行的系统来说是一个问题。
JAX 使用更适合现代加速器设备的浮点提升规则,并且对提升浮点类型没有那么激进。JAX 用于浮点类型的提升规则类似于 PyTorch 使用的提升规则。
Python 运算符分派的效应#
请记住,像 + 这样的 Python 运算符将根据要添加的两个值的 Python 类型进行分派。这意味着,例如,np.int16(1) + 1
将使用 NumPy 规则进行提升,而 jnp.int16(1) + 1
将使用 JAX 规则进行提升。当两种提升方式组合在一起时,这可能会导致潜在的令人困惑的非关联性提升语义;例如,np.int16(1) + 1 + jnp.int16(1)
。
JAX 中的弱类型值#
在大多数情况下,JAX 中的弱类型值可以被认为具有与 Python 标量等效的提升行为,例如以下代码中的整数标量 2
>>> x = jnp.arange(5, dtype='int8')
>>> 2 * x
Array([0, 2, 4, 6, 8], dtype=int8)
JAX 的弱类型框架旨在防止在 JAX 值和没有明确用户指定类型的值(例如 Python 标量字面量)之间的二元运算中发生意外的类型提升。例如,如果 2
未被视为弱类型,则上述表达式将导致隐式类型提升
>>> jnp.int32(2) * x
Array([0, 2, 4, 6, 8], dtype=int32)
在 JAX 中使用时,Python 标量有时会被提升为 DeviceArray
对象,例如在 JIT 编译期间。为了在这种情况下保持所需的提升语义,DeviceArray
对象带有 weak_type
标志,可以在数组的字符串表示中看到该标志
>>> jnp.asarray(2)
Array(2, dtype=int32, weak_type=True)
如果显式指定 dtype
,它将改为生成标准的强类型数组值
>>> jnp.asarray(2, dtype='int32')
Array(2, dtype=int32)
严格的 dtype 提升#
在某些情况下,禁用隐式类型提升行为,而是要求所有提升都必须显式,这样做会很有用。这可以通过在 JAX 中将 jax_numpy_dtype_promotion
标志设置为 'strict'
来实现。在本地,可以使用上下文管理器来实现
>>> x = jnp.float32(1)
>>> y = jnp.int32(1)
>>> with jax.numpy_dtype_promotion('strict'):
... z = x + y
...
Traceback (most recent call last):
TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit
dtype promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting
inputs to the desired output type, or set jax_numpy_dtype_promotion=standard.
为了方便起见,严格提升模式仍然允许安全的弱类型提升,因此您仍然可以编写混合了 JAX 数组和 Python 标量的代码
>>> with jax.numpy_dtype_promotion('strict'):
... z = x + 1
>>> print(z)
2.0
如果您希望全局设置配置,可以使用标准配置更新来实现
jax.config.update('jax_numpy_dtype_promotion', 'strict')
要恢复默认的标准类型提升,将此配置设置为 'standard'
jax.config.update('jax_numpy_dtype_promotion', 'standard')