类型提升语义#

本文档描述了 JAX 的类型提升规则,即对于每对类型,jax.numpy.promote_types() 的结果。有关以下描述的设计考虑的背景信息,请参阅JAX 类型提升语义的设计

JAX 的类型提升行为由以下类型提升格决定

_images/type_lattice.svg

例如,

  • b1 表示 np.bool_

  • i2 表示 np.int16

  • u4 表示 np.uint32

  • bf 表示 np.bfloat16

  • f2 表示 np.float16

  • c8 表示 np.complex64

  • i* 表示 Python int 或弱类型 int

  • f* 表示 Python float 或弱类型 float,并且

  • c* 表示 Python complex 或弱类型 complex

(有关弱类型的更多信息,请参阅下面的 JAX 中的弱类型值)。

任意两种类型之间的提升由它们在此格上的给出,这会生成以下二进制提升表

b1u1u2u4u8i1i2i4i8bff2f4f8c8c16i*f*c*
b1b1u1u2u4u8i1i2i4i8bff2f4f8c8c16i*f*c*
u1u1u1u2u4u8i2i2i4i8bff2f4f8c8c16u1f*c*
u2u2u2u2u4u8i4i4i4i8bff2f4f8c8c16u2f*c*
u4u4u4u4u4u8i8i8i8i8bff2f4f8c8c16u4f*c*
u8u8u8u8u8u8f*f*f*f*bff2f4f8c8c16u8f*c*
i1i1i2i4i8f*i1i2i4i8bff2f4f8c8c16i1f*c*
i2i2i2i4i8f*i2i2i4i8bff2f4f8c8c16i2f*c*
i4i4i4i4i8f*i4i4i4i8bff2f4f8c8c16i4f*c*
i8i8i8i8i8f*i8i8i8i8bff2f4f8c8c16i8f*c*
bfbfbfbfbfbfbfbfbfbfbff4f4f8c8c16bfbfc8
f2f2f2f2f2f2f2f2f2f2f4f2f4f8c8c16f2f2c8
f4f4f4f4f4f4f4f4f4f4f4f4f4f8c8c16f4f4c8
f8f8f8f8f8f8f8f8f8f8f8f8f8f8c16c16f8f8c16
c8c8c8c8c8c8c8c8c8c8c8c8c8c16c8c16c8c8c8
c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16
i*i*u1u2u4u8i1i2i4i8bff2f4f8c8c16i*f*c*
f*f*f*f*f*f*f*f*f*f*bff2f4f8c8c16f*f*c*
c*c*c*c*c*c*c*c*c*c*c8c8c8c16c8c16c*c*c*

JAX 的类型提升规则与 NumPy 的规则(由 numpy.promote_types() 给出)不同,上表中以绿色背景突出显示的单元格表示差异。主要有三类差异

  • 当将弱类型值与相同类别的已键入 JAX 值进行提升时,JAX 始终偏向于 JAX 值的精度。例如,jnp.int16(1) + 1 将返回 int16,而不是像 NumPy 那样提升到 int64。请注意,这仅适用于 Python 标量值;如果常量是 NumPy 数组,则使用上面的格进行类型提升。例如,jnp.int16(1) + np.array(1) 将返回 int64

  • 当将整数或布尔类型与浮点或复数类型进行提升时,JAX 始终偏向于浮点或复数类型的类型。

  • JAX 支持 bfloat16 非标准 16 位浮点类型(jax.numpy.bfloat16),该类型对于神经网络训练非常有用。唯一值得注意的提升行为是关于 IEEE-754 float16bfloat16 会将其提升为 float32

NumPy 和 JAX 之间的差异是由于加速器设备(例如 GPU 和 TPU)要么在使用 64 位浮点类型时会产生显著的性能损失(GPU),要么根本不支持 64 位浮点类型(TPU)。传统的 NumPy 提升规则过于倾向于过度提升到 64 位类型,这对于旨在在加速器上运行的系统来说是有问题的。

JAX 使用更适合现代加速器设备的浮点提升规则,并且在提升浮点类型方面不那么激进。JAX 用于浮点类型的提升规则类似于 PyTorch 使用的规则。

Python 运算符分派的影响#

请记住,像 + 这样的 Python 运算符将基于正在相加的两个值的 Python 类型进行分派。这意味着,例如,np.int16(1) + 1 将使用 NumPy 规则进行提升,而 jnp.int16(1) + 1 将使用 JAX 规则进行提升。当两种提升类型组合在一起时,这可能会导致潜在的令人困惑的非关联提升语义;例如,np.int16(1) + 1 + jnp.int16(1)

JAX 中的弱类型值#

在大多数情况下,JAX 中的弱类型值可以被视为具有与 Python 标量等效的提升行为,例如以下示例中的整数标量 2

>>> x = jnp.arange(5, dtype='int8')
>>> 2 * x
Array([0, 2, 4, 6, 8], dtype=int8)

JAX 的弱类型框架旨在防止 JAX 值和没有显式用户指定类型的值(例如 Python 标量文字)之间的二进制操作中出现不需要的类型提升。例如,如果 2 不被视为弱类型,则上面的表达式会导致隐式类型提升

>>> jnp.int32(2) * x
Array([0, 2, 4, 6, 8], dtype=int32)

在 JAX 中使用时,Python 标量有时会提升为 DeviceArray 对象,例如在 JIT 编译期间。为了在这种情况下保持所需的提升语义,DeviceArray 对象带有 weak_type 标志,该标志可以在数组的字符串表示中看到

>>> jnp.asarray(2)
Array(2, dtype=int32, weak_type=True)

如果显式指定了 dtype,则它将改为产生标准的强类型数组值

>>> jnp.asarray(2, dtype='int32')
Array(2, dtype=int32)

严格的 dtype 提升#

在某些情况下,禁用隐式类型提升行为并改为要求所有提升都是显式的可能会很有用。这可以通过将 jax_numpy_dtype_promotion 标志设置为 'strict' 在 JAX 中完成。在本地,可以使用上下文管理器来完成

>>> x = jnp.float32(1)
>>> y = jnp.int32(1)
>>> with jax.numpy_dtype_promotion('strict'):
...   z = x + y  
...
Traceback (most recent call last):
TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit
dtype promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting
inputs to the desired output type, or set jax_numpy_dtype_promotion=standard.

为方便起见,严格提升模式仍然允许安全的弱类型提升,因此您仍然可以编写混合 JAX 数组和 Python 标量的代码

>>> with jax.numpy_dtype_promotion('strict'):
...   z = x + 1
>>> print(z)
2.0

如果您希望全局设置配置,可以使用标准配置更新来完成

jax.config.update('jax_numpy_dtype_promotion', 'strict')

要恢复默认标准类型提升,请将此配置设置为 'standard'

jax.config.update('jax_numpy_dtype_promotion', 'standard')