JAX 类型提升语义的设计#

Open in Colab Open in Kaggle

Jake VanderPlas,2021 年 12 月

任何数值计算库在设计中面临的挑战之一是如何处理不同类型值之间的运算。本文档概述了 JAX 所使用的提升语义背后的思路,总结在 JAX 类型提升语义中。

JAX 类型提升的目标#

JAX 的数值计算 API 仿照 NumPy,并进行了一些增强,包括支持 GPU 和 TPU 等加速器。这使得采用 NumPy 的类型提升系统对 JAX 用户不利:NumPy 的类型提升规则严重倾向于 64 位输出,这对在加速器上进行计算来说是有问题的。GPU 和 TPU 等设备通常在使用 64 位浮点类型时会产生明显的性能损失,并且在某些情况下根本不支持原生的 64 位浮点类型。

32 位整数和浮点数之间的二元运算可以简单地看出这种有问题的类型提升语义

import numpy as np
np.dtype(np.int32(1) + np.float32(1))
dtype('float64')

NumPy 倾向于产生 64 位值是使用 NumPy 的 API 进行加速器计算的长期存在的问题,目前还没有好的解决方案。因此,JAX 试图重新思考针对加速器的 NumPy 式类型提升。

回顾:表格和格#

在我们深入细节之前,让我们花点时间回顾一下如何思考类型提升的问题。考虑 Python 中内置数值类型之间的算术运算,即 intfloatcomplex 类型。通过几行代码,我们可以生成 Python 用于这些类型值之间加法的类型提升表

import pandas as pd
types = [int, float, complex]
name = lambda t: t.__name__
pd.DataFrame([[name(type(t1(1) + t2(1))) for t1 in types] for t2 in types],
             index=[name(t) for t in types], columns=[name(t) for t in types])
整数 浮点数 复数
整数 整数 浮点数 复数
浮点数 浮点数 浮点数 复数
复数 复数 复数 复数

此表枚举了 Python 的数值类型提升行为,但事实证明,还有一种更紧凑的表示形式:表示,其中任何两个节点之间的上确界是它们提升到的类型。Python 的提升表的格表示更简单

隐藏代码单元格来源
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {'int': ['float'], 'float': ['complex']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {'int': [0, 0], 'float': [1, 0], 'complex': [2, 0]}
fig, ax = plt.subplots(figsize=(8, 2))
nx.draw(graph, with_labels=True, node_size=4000, node_color='lightgray', pos=pos, ax=ax, arrowsize=20)
../_images/818a3cf499d15c3be1d4c116db142da0418c174873f21e1ffcde679c6058f918.png

这个格是上面提升表中信息的紧凑编码。您可以通过在图中追踪到两个节点的第一个公共子节点(包括节点本身)来找到两个输入的类型提升结果;在数学上,这个公共子节点被称为格上这对节点的上确界最小上界;在这里我们将此操作称为

从概念上讲,箭头表示允许在源和目标之间进行隐式类型提升:例如,允许从整数隐式提升到浮点数,但不允许从浮点数隐式提升到整数。

请记住,通常并非每个有向无环图 (DAG) 都满足格的属性。格要求每对节点都存在唯一的最小上界;例如,以下两个 DAG 不是格

隐藏代码单元格来源
#@title
import networkx as nx
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 2, figsize=(10, 2))

lattice = {'A': ['B', 'C']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {'A': [0, 0], 'B': [1, 0.5], 'C': [1, -0.5]}
nx.draw(graph, with_labels=True, node_size=2000, node_color='lightgray', pos=pos, ax=ax[0], arrowsize=20)
ax[0].set(xlim=[-0.5, 1.5], ylim=[-1, 1])

lattice = {'A': ['C', 'D'], 'B': ['C', 'D']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {'A': [0, 0.5], 'B': [0, -0.5], 'C': [1, 0.5], 'D': [1, -0.5]}
nx.draw(graph, with_labels=True, node_size=2000, node_color='lightgray', pos=pos, ax=ax[1], arrowsize=20)
ax[1].set(xlim=[-0.5, 1.5], ylim=[-1, 1]);
../_images/a0acbd07f9486d95c10a36c11301d528fb7e65d671d622226151c431b3e36c62.png

左侧的 DAG 不是格,因为节点 BC 不存在上界;右侧的 DAG 在两个方面都失败了:首先,节点 CD 不存在上界,对于节点 AB,最小上界无法唯一确定:CD 都是候选者,但它们是无法排序的。

类型提升格的性质#

根据格指定类型提升可确保许多有用的属性。用 \(\vee\) 运算符表示格上的并,我们有

存在性:根据定义,格要求每对元素都存在唯一的格并:\(\forall (a, b): \exists !(a \vee b)\)

交换性:格并具有交换性:\(\forall (a, b): a\vee b = b \vee a\)

结合性:格并具有结合性:\(\forall (a, b, c): a \vee (b \vee c) = (a \vee b) \vee c\)

另一方面,这些属性意味着它们可以表示的类型提升系统存在限制;特别地,并非每个类型提升表都可以用格表示。一个现成的例子是 NumPy 的完整类型提升表;这可以通过反例快速证明:以下是 NumPy 中提升行为不具有结合性的三种标量类型

import numpy as np
a, b, c = np.int8(1), np.uint8(1), np.float16(1)
print(np.dtype((a + b) + c))
print(np.dtype(a + (b + c)))
float32
float16

这样的结果可能会让用户感到惊讶:我们通常期望数学表达式映射到数学概念,因此,例如,a + b + c 应该等同于 c + b + ax * (y + z) 应该等同于 x * y + x * z。如果类型提升不具有结合性或交换性,则这些属性不再适用。

此外,与基于表格的系统相比,基于格的类型提升系统在概念化和理解方面更简单。例如,JAX 识别 18 种不同的类型:一个包含 18 个节点且节点之间具有稀疏、合理连接的提升格比包含 324 个条目的表格更容易记在脑海中。

因此,我们选择为 JAX 使用基于格的类型提升系统。

类别内的类型提升#

数值计算库通常提供的不仅仅是 intfloatcomplex;在每个类别中,都有各种可能的精度,用数值表示中使用的位数表示。我们将在此考虑的类别包括

  • 无符号整数,包括 uint8uint16uint32uint64(我们将使用 u8u16u32u64 作为简称)

  • 有符号整数,包括 int8int16int32int64(我们将使用 i8i16i32i64 作为简称)

  • 浮点数,包括 float16float32float64(我们将使用 f16f32f64 作为简称)

  • 复数浮点数,包括 complex64complex128(我们将使用 c64c128 作为简称)

Numpy 在这四个类别的类型提升语义相对简单:类型的有序层次结构直接转换为四个单独的格,表示类别内的类型提升规则

隐藏代码单元格来源
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'u8': [0, 0], 'u16': [1, 0], 'u32': [2, 0], 'u64': [3, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [1, 2], 'f32': [2, 2], 'f64': [3, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 4))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/2d8495bcb006c34b42eeb4f3e0c6530fdef0bd7364c56184993925f0cf157abc.png

就 JAX 试图避免的将值提升为 64 位而言,每个类型类别内的这些同类提升语义没有问题:产生 64 位输出的唯一方法是拥有一个 64 位输入。

Python 标量的引入#

现在让我们思考一下 Python 标量如何融入其中。

在 NumPy 中,提升行为因输入是数组还是标量而异。例如,当对两个标量进行运算时,适用正常的提升规则

x = np.int8(0)  # int8 scalar
y = 1  # Python int = int64 scalar
(x + y).dtype
dtype('int64')

在这里,Python 值 1 被视为 int64,并且直接的类别内规则导致 int64 结果。

但是,在 Python 标量和 NumPy 数组之间的运算中,标量会遵循数组的 dtype。例如

x = np.zeros(1, dtype='int8')  # int8 array
y = 1  # Python int = int64 scalar
(x + y).dtype
dtype('int8')

在这里,int64 标量的位宽被忽略,而遵循数组的位宽。

这里还有另一个细节:当 NumPy 类型提升涉及标量时,输出 dtype 取决于值:如果 Python 标量对于给定的 dtype 而言太大,则会将其提升为兼容的类型

x = np.zeros(1, dtype='int8')  # int8 array
y = 1000  # int64 scalar
(x + y).dtype
dtype('int16')

对于 JAX 来说,值依赖的类型提升是不可行的,因为 JIT 编译和其他转换的本质是在不参考数据值的情况下对数据的抽象表示进行操作。

忽略值依赖的影响,NumPy 的类型提升中有符号整数分支可以用以下格表示,其中我们将使用 * 来标记标量数据类型

隐藏代码单元格来源
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i8*': ['i16*'], 'i16*': ['i32*'], 'i32*': ['i64*'], 'i64*': ['i8'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i8*': [0, 1], 'i16*': [2, 1], 'i32*': [4, 1], 'i64*': [6, 1],
  'i8': [9, 1], 'i16': [11, 1], 'i32': [13, 1], 'i64': [15, 1],
}
fig, ax = plt.subplots(figsize=(12, 4))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
ax.text(3, 1.6, "Scalar Types", ha='center', fontsize=14)
ax.text(12, 1.6, "Array Types", ha='center', fontsize=14)
ax.set_ylim(-1, 3);
../_images/7e8c3295e403209560d8e142c5c830d79456a4e6d207dd1a7e4d15b55c56006b.png

uintfloatcomplex 格中也存在类似的模式。

为了简单起见,让我们将每种标量类型都归纳为一个节点,分别用 u*i*f*c* 表示。我们的类别内格现在可以表示如下

隐藏代码单元格来源
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'u*': ['u8'], 'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i*': ['i8'], 'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f*': ['f16'], 'f16': ['f32'], 'f32': ['f64'],
  'c*': ['c64'], 'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'u*': [0, 0], 'u8': [3, 0], 'u16': [5, 0], 'u32': [7, 0], 'u64': [9, 0],
  'i*': [0, 1], 'i8': [3, 1], 'i16': [5, 1], 'i32': [7, 1], 'i64': [9, 1],
  'f*': [0, 2], 'f16': [5, 2], 'f32': [7, 2], 'f64': [9, 2],
  'c*': [0, 3], 'c64': [7, 3], 'c128': [9, 3],
}
fig, ax = plt.subplots(figsize=(6, 4))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/0fbe0c20cd350821e64f3742aa7864ec729565572b136950042095881672fdb9.png

从某种意义上说,将标量放在最左边是一种奇怪的选择:标量类型可以包含任何宽度的值,但是当与给定类型的数组交互时,提升结果会遵循数组类型。这样做的好处是,当您对数组 x 执行 x + 2 之类的操作时,无论其宽度如何,x 的类型都会传递到结果中。

for dtype in [np.int8, np.int16, np.int32, np.int64]:
  x = np.arange(10, dtype=dtype)
  assert (x + 2).dtype == dtype

此行为为我们为标量值使用 * 符号提供了动机:* 让人想起一个可以采用任何所需值的通配符。

这些语义的好处是,您可以轻松地使用简洁的 Python 代码表达一系列操作,而无需将标量显式转换为适当的类型。想象一下,如果不是这样写

3 * (x + 1) ** 2

您必须这样写

np.int32(3) * (x + np.int32(1)) ** np.int32(2)

尽管它是显式的,但数值代码将变得难以阅读或编写。使用上述标量提升语义,给定一个类型为 int32 的数组 x,则第二条语句中的类型在第一条语句中是隐式的。

组合格#

回想一下,我们首先介绍了表示 Python 中类型提升的格:int -> float -> complex。让我们将其重写为 i* -> f* -> c*,并且进一步允许 i* 包含 u*(毕竟,Python 中没有无符号整数标量类型)。

将所有这些放在一起,我们得到以下表示 Python 标量和 numpy 数组之间类型提升的部分格

隐藏代码单元格来源
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/796586be87180b0de3171d39763f2d33a80a641b72d82c00f0c0e352f754f201.png

请注意,这还不是一个真正的格:对于许多节点对,不存在并。但是,我们可以将其视为一个部分格,其中某些节点对没有定义的提升行为,并且这个部分格的定义部分正确地描述了 NumPy 的数组提升行为(忽略上面提到的值依赖语义)。

这为我们提供了一个很好的框架,通过在该图中添加连接来考虑填充这些未定义的提升规则。但是要添加哪些连接呢?总的来说,我们希望任何额外的连接都满足以下几个属性

  1. 类型提升应满足交换性和结合性:换句话说,该图应保持为(部分)格。

  2. 类型提升绝不应允许删除数据的整个组成部分:例如,我们绝不应将 complex 提升为 float,因为它会丢弃任何虚部。

  3. 类型提升绝不应导致未处理的溢出。例如,最大的 uint32 是最大的 int32 的两倍,因此我们不应隐式地将 uint32 提升为 int32

  4. 在可能的情况下,类型提升应避免精度损失。例如,int64 值可能具有 64 位尾数,因此将 int64 提升为 float64 可能导致精度损失。但是,最大的可表示 float64 大于最大的可表示 int64,因此在这种情况下,标准 #3 仍然得到满足。

  5. 在可能的情况下,二进制提升应避免产生比输入更宽的类型。这是为了确保 JAX 的隐式提升对基于加速器的工作流程保持友好,在这些工作流程中,用户通常希望将类型限制为 32 位(或在某些情况下为 16 位)值。

格上的每个新连接都为用户引入了一定程度的便利性(一组无需显式转换即可交互的新类型),但是如果违反上述任何标准,这种便利性可能会变得代价过高。开发完整的提升格需要在这种便利性和这种代价之间取得平衡。

混合类型提升:浮点型和复数型#

让我们从也许是最简单的情况开始,即浮点值和复数值之间的提升。

复数由成对的浮点数组成,因此我们在这两者之间有自然的提升路径:将浮点数转换为复数,同时保持实部的宽度。用我们的部分格表示,它看起来像这样

隐藏代码单元格来源
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/bf87909b2344aed80590d1c6d91585a02b25898ac217526cb49948d91205318f.png

事实证明,这恰好表示 Numpy 在混合浮点/复数类型提升中使用的语义。

混合类型提升:有符号整数和无符号整数#

对于下一种情况,让我们考虑更困难的事情:有符号整数和无符号整数之间的提升。例如,在将 uint8 提升为有符号整数时,我们需要多少位?

乍一看,您可能会认为将 uint8 提升为 int8 是很自然的;但是 uint8 中最大的数字在 int8 中是不可表示的。因此,将无符号整数提升为位数加倍的整数更有意义;可以通过在提升格中添加以下连接来表示此提升行为

隐藏代码单元格来源
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/3be7e17889458ac823bb5dacf31525c0d96578c6854962f45dcc60ec987a30bd.png

同样,此处添加的连接正是 Numpy 为混合整数提升实现的提升语义。

如何处理 uint64#

混合有符号/无符号整数提升的方法遗漏了一种类型:uint64。按照上面的模式,涉及 uint64 的混合整数运算的输出应产生 int128,但这并不是标准的可用数据类型。

Numpy 在此处的选择是提升为 float64

(np.uint64(1) + np.int64(1)).dtype
dtype('float64')

但是,这可能是一个令人惊讶的约定:这是整数类型提升不产生整数的唯一情况。现在,我们将使 uint64 的提升未定义,并在以后返回到此。

混合类型提升:整数和浮点型#

当将整数提升为浮点数时,我们可能会以与有符号整数和无符号整数之间混合提升相同的思路开始。一个 16 位的有符号或无符号整数不能由一个只有 10 位尾数的 16 位浮点数完整精度地表示。因此,将整数提升为位数加倍表示的浮点数可能是有意义的

隐藏代码单元格来源
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16', 'f16'], 'u16': ['u32', 'i32', 'f32'], 'u32': ['u64', 'i64', 'f64'],
  'i8': ['i16', 'f16'], 'i16': ['i32', 'f32'], 'i32': ['i64', 'f64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/8b3247e8189fbfad46a7e5583b636866fc45576e07c9bfd904457926306299d1.png

这实际上就是 Numpy 类型提升所做的,但是这样做会打破该图的格属性:例如,{i8, u8} 对不再具有唯一的最小上界:可能性是 i16f16,它们在图上是不可排序的。事实证明,这是上述 NumPy 非关联类型提升的根源。

我们能否对 NumPy 的提升规则进行修改,使其满足格属性,同时为混合类型提升提供合理的结果?我们可以采取以下几种方法。

选项 0:将整数/浮点混合精度保留为未定义#

为了使行为完全可预测(以牺牲用户便利性为代价),一种合理的选择是将超出 Python 标量的任何混合整数/浮点提升保留为未定义,并以上一节中的部分格停止。缺点是要求用户在整数和浮点数量之间进行运算时显式进行类型转换。

选项 1:避免所有精度损失#

如果我们的重点是以不惜一切代价避免精度损失,我们可以通过无符号整数现有的有符号整数路径将其提升为浮点数,从而恢复格属性

隐藏代码单元格来源
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16', 'f16'], 'i16': ['i32', 'f32'], 'i32': ['i64', 'f64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/1eda89d008a8c6dadf926229bf9f2245722006c5bc1c42961c555a2595c95117.png

这种方法的一个缺点是,它仍然使 int64uint64 的提升未定义,因为没有标准的浮点类型具有足够的尾数位来表示它们的全部取值范围。我们可以放宽精度约束,并通过建立 i64->f64u64->f64 的连接来完成格结构,但这些连接将与此提升方案的动机背道而驰。

第二个缺点是,这个格结构使得很难找到一个合适的位置来插入 bfloat16(见下文),同时保持格的属性。

这种方法的第三个缺点,对于 JAX 的加速器后端来说更重要,是一些操作会导致类型比必要的更宽;例如,uint16float16 之间的混合操作会一直提升到 float64,这并不理想。

选项 2:避免大多数不必要的更宽的类型提升#

为了解决不必要地提升到更宽类型的问题,我们可以接受整数/浮点提升中存在一些精度损失的可能性,将有符号整数提升到相同宽度的浮点数。

隐藏代码单元格来源
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16'], 'i16': ['f16', 'i32'], 'i32': ['f32', 'i64'], 'i64': ['f64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [1.5, 2], 'f32': [2.5, 2], 'f64': [3.5, 2],
  'c64': [3, 3], 'c128': [4, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/f41cee38a476bf636be901e7f64a5dc3687002f9d12532ab706b9077d602b175.png

虽然这确实允许整数和浮点数之间存在精度损失的提升,但这些提升不会错误地表示结果的大小:虽然浮点尾数不够宽以表示所有值,但指数足够宽以近似它们。

这种方法还允许从 int64float64 的自然提升路径,尽管 uint64 在此方案中仍然不可提升。也就是说,从 u64f64 的连接比以前更容易被证明是合理的。

这种提升方案仍然会导致一些比必要的更宽的提升路径;例如,float32uint32 之间的操作会导致 float64。此外,这个格结构使得很难找到一个合适的位置来插入 bfloat16(见下文),同时保持格的属性。

选项 3:避免所有不必要的更宽的类型提升#

如果我们愿意从根本上改变我们对整数和浮点数提升的看法,我们可以避免所有不理想的 64 位提升。正如标量总是服从于数组类型的宽度一样,我们可以让整数总是服从于浮点类型的宽度。

隐藏代码单元格来源
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'], 'i64': ['f*'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [1.5, 2], 'f32': [2.5, 2], 'f64': [3.5, 2],
  'c64': [3, 3], 'c128': [4, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/d3f5e5be4354238a60698cb4f228d4e1f75a665577343c36b2c1ade1207783a0.png

这涉及到一个小技巧:之前我们使用 f* 来指代标量类型。在这个格结构中,f* 可能应用于混合计算的数组输出。与其将 f* 视为标量,不如将其视为一种特殊的 float 值,具有不同的提升规则:在 JAX 中,我们将其称为弱浮点数;见下文。

这种方法的优点是,除了无符号整数外,它避免了所有不必要的更宽的类型提升:如果没有 64 位输入,您永远不会得到 f64 输出,如果没有 32 位输入,您永远不会得到 f32 输出:这为在加速器上工作提供了方便的语义,同时避免了意外的 64 位值。

这种给予浮点类型优先级的特性类似于 PyTorch 的类型提升行为。这个格结构也恰好生成了一个非常类似于 JAX 最初的临时类型提升方案的提升表,该方案不是基于格结构,但具有给予浮点类型优先级的特性。

此外,这个格结构还提供了一个自然的位置来插入 bfloat16,而无需在 bf16f16 之间强制排序。

隐藏代码单元格来源
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['u8', 'i8'], 'f*': ['c*', 'f16', 'bf16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'], 'i64': ['f*'],
  'f16': ['f32'], 'bf16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [1.8, 1.7], 'bf16': [1.8, 2.3], 'f32': [3.0, 2], 'f64': [4.0, 2],
  'c64': [3.5, 3], 'c128': [4.5, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/aa73688b580b02776fce218d6efe58792ae3b0976160a4b0c130b797780578af.png

这一点很重要,因为 f16bf16 是不可比较的,因为它们以不同的方式利用其位:bf16 以较低的精度表示更大的范围,而 f16 以较高的精度表示较小的范围。

然而,这些优点也带来了一些权衡。

  • 混合浮点数/整数提升非常容易导致精度损失:例如,int64(最大值为 \(9.2 \times 10^{18}\))可以提升到 float16(最大值为 \(6.5 \times 10^4\)),这意味着大多数可表示的值将变为 inf

  • 如上所述,f* 不能再被认为是“标量类型”,而是一种不同的 float64 类型。在 JAX 的术语中,这被称为弱类型,因为它以 64 位表示,但在与其他值提升时仅弱保持此位宽度。

请注意,这种方法仍然没有回答 uint64 的提升问题,尽管通过将 u64 连接到 f* 来关闭格结构可能是合理的。

JAX 中的类型提升#

在设计 JAX 的类型提升语义时,我们牢记了许多这些想法,并严重依赖于以下几点:

  1. 我们选择将 JAX 的类型提升语义限制为满足格属性的图:这是为了确保结合性和交换性,但也允许语义以 DAG 的形式紧凑地描述,而不是需要一个大的表格。

  2. 我们倾向于避免意外提升到更宽类型的语义,尤其是在涉及到浮点值时,以便有利于在加速器上进行计算。

  3. 如果需要保持 (1) 和 (2),我们很乐意接受混合类型提升中潜在的精度损失(但不是大小的损失)。

考虑到这一点,JAX 采用了选项 3。或者更确切地说,是选项 3 的稍微修改的版本,它在 u64f* 之间建立连接,以创建一个真正的格。为了清晰起见,重新排列节点后,JAX 的类型提升格如下所示:

隐藏代码单元格来源
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['u8', 'i8'], 'f*': ['c*', 'f16', 'bf16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'], 'u64': ['f*'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'], 'i64': ['f*'],
  'f16': ['f32'], 'bf16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [4.5, 0.5], 'c*': [5, 1.5],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [5.75, 0.8], 'bf16': [5.75, 0.2], 'f32': [7, 0.5], 'f64': [8, 0.5],
  'c64': [7.5, 1.5], 'c128': [8.5, 1.5],
}
fig, ax = plt.subplots(figsize=(10, 4))
ax.set_ylim(-0.5, 2)
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
# ax.patches[12].set_linestyle((0, (2, 4)))
../_images/d261add493a579484d9772634ce146f1240af3966d0845839c354417a3de2e53.png

此选择产生的行为在 JAX 类型提升语义中进行了总结。值得注意的是,除了包含更大的无符号类型(u16u32u64)以及一些关于标量/弱类型(i*f*c*)的行为的详细信息外,这种类型提升方案与 PyTorch 选择的方案非常接近。

对于那些感兴趣的人,下面的附录打印了 NumPy、Tensorflow、PyTorch 和 JAX 使用的完整提升表。

附录:示例类型提升表#

以下是各种 Python 数组计算库实现的隐式类型提升表的一些示例。

NumPy 类型提升#

请注意,NumPy 不包含 bfloat16 数据类型,并且下表忽略了依赖于值的影响。

隐藏代码单元格来源
# @title

import numpy as np
import pandas as pd
from IPython import display

np_dtypes = {
  'b': np.bool_,
  'u8': np.uint8, 'u16': np.uint16, 'u32': np.uint32, 'u64': np.uint64,
  'i8': np.int8, 'i16': np.int16, 'i32': np.int32, 'i64': np.int64,
  'bf16': 'invalid', 'f16': np.float16, 'f32': np.float32, 'f64': np.float64,
  'c64': np.complex64, 'c128': np.complex128,
  'i*': int, 'f*': float, 'c*': complex}

np_dtype_to_code = {val: key for key, val in np_dtypes.items()}

def make_np_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return np.zeros(1, dtype=dtype)

def np_result_code(dtype1, dtype2):
  try:
    out = np.add(make_np_zero(dtype1), make_np_zero(dtype2))
  except TypeError:
    return '-'
  else:
    if type(out) in {int, float, complex}:
      return np_dtype_to_code[type(out)]
    else:
      return np_dtype_to_code[out.dtype.type]


grid = [[np_result_code(dtype1, dtype2)
         for dtype2 in np_dtypes.values()]
        for dtype1 in np_dtypes.values()]
table = pd.DataFrame(grid, index=np_dtypes.keys(), columns=np_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b b u8 u16 u32 u64 i8 i16 i32 i64 - f16 f32 f64 c64 c128 i64 f64 c128
u8 u8 u8 u16 u32 u64 i16 i16 i32 i64 - f16 f32 f64 c64 c128 u8 f64 c128
u16 u16 u16 u16 u32 u64 i32 i32 i32 i64 - f32 f32 f64 c64 c128 u16 f64 c128
u32 u32 u32 u32 u32 u64 i64 i64 i64 i64 - f64 f64 f64 c128 c128 u32 f64 c128
u64 u64 u64 u64 u64 u64 f64 f64 f64 f64 - f64 f64 f64 c128 c128 u64 f64 c128
i8 i8 i16 i32 i64 f64 i8 i16 i32 i64 - f16 f32 f64 c64 c128 i8 f64 c128
i16 i16 i16 i32 i64 f64 i16 i16 i32 i64 - f32 f32 f64 c64 c128 i16 f64 c128
i32 i32 i32 i32 i64 f64 i32 i32 i32 i64 - f64 f64 f64 c128 c128 i32 f64 c128
i64 i64 i64 i64 i64 f64 i64 i64 i64 i64 - f64 f64 f64 c128 c128 i64 f64 c128
bf16 - - - - - - - - - - - - - - - - - -
f16 f16 f16 f32 f64 f64 f16 f32 f64 f64 - f16 f32 f64 c64 c128 f16 f16 c64
f32 f32 f32 f32 f64 f64 f32 f32 f64 f64 - f32 f32 f64 c64 c128 f32 f32 c64
f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 - f64 f64 f64 c128 c128 f64 f64 c128
c64 c64 c64 c64 c128 c128 c64 c64 c128 c128 - c64 c64 c128 c64 c128 c64 c64 c64
c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 - c128 c128 c128 c128 c128 c128 c128 c128
i* i64 u8 u16 u32 u64 i8 i16 i32 i64 - f16 f32 f64 c64 c128 i64 f64 c128
f* f64 f64 f64 f64 f64 f64 f64 f64 f64 - f16 f32 f64 c64 c128 f64 f64 c128
c* c128 c128 c128 c128 c128 c128 c128 c128 c128 - c64 c64 c128 c64 c128 c128 c128 c128

Tensorflow 类型提升#

Tensorflow 避免定义隐式类型提升,除了在有限的情况下使用 Python 标量。该表是不对称的,因为在 tf.add(x, y) 中,y 的类型必须可以强制转换为 x 的类型。

隐藏代码单元格来源
# @title

import tensorflow as tf
import pandas as pd
from IPython import display

tf_dtypes = {
  'b': tf.bool,
  'u8': tf.uint8, 'u16': tf.uint16, 'u32': tf.uint32, 'u64': tf.uint64,
  'i8': tf.int8, 'i16': tf.int16, 'i32': tf.int32, 'i64': tf.int64,
  'bf16': tf.bfloat16, 'f16': tf.float16, 'f32': tf.float32, 'f64': tf.float64,
  'c64': tf.complex64, 'c128': tf.complex128,
  'i*': int, 'f*': float, 'c*': complex}

tf_dtype_to_code = {val: key for key, val in tf_dtypes.items()}

def make_tf_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return tf.zeros(1, dtype=dtype)

def result_code(dtype1, dtype2):
  try:
    out = tf.add(make_tf_zero(dtype1), make_tf_zero(dtype2))
  except (TypeError, tf.errors.InvalidArgumentError):
    return '-'
  else:
    if type(out) in {int, float, complex}:
      return tf_dtype_to_code[type(out)]
    else:
      return tf_dtype_to_code[out.dtype]


grid = [[result_code(dtype1, dtype2)
         for dtype2 in tf_dtypes.values()]
        for dtype1 in tf_dtypes.values()]
table = pd.DataFrame(grid, index=tf_dtypes.keys(), columns=tf_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b - - - - - - - - - - - - - - - - - -
u8 - u8 - - - - - - - - - - - - - u8 - -
u16 - - u16 - - - - - - - - - - - - u16 - -
u32 - - - u32 - - - - - - - - - - - u32 - -
u64 - - - - u64 - - - - - - - - - - u64 - -
i8 - - - - - i8 - - - - - - - - - i8 - -
i16 - - - - - - i16 - - - - - - - - i16 - -
i32 - - - - - - - i32 - - - - - - - i32 - -
i64 - - - - - - - - i64 - - - - - - i64 - -
bf16 - - - - - - - - - bf16 - - - - - bf16 bf16 -
f16 - - - - - - - - - - f16 - - - - f16 f16 -
f32 - - - - - - - - - - - f32 - - - f32 f32 -
f64 - - - - - - - - - - - - f64 - - f64 f64 -
c64 - - - - - - - - - - - - - c64 - c64 c64 c64
c128 - - - - - - - - - - - - - - c128 c128 c128 c128
i* - - - - - - - i32 - - - - - - - i32 - -
f* - - - - - - - - - - - f32 - - - f32 f32 -
c* - - - - - - - - - - - - - - c128 c128 c128 c128

PyTorch 类型提升#

请注意,torch 不包含大于 uint8 的无符号整数类型。除了这一点以及一些关于使用标量/弱类型提升的详细信息外,该表与 jax.numpy 使用的表非常接近。

隐藏代码单元格来源
# @title
import torch
import pandas as pd
from IPython import display

torch_dtypes = {
  'b': torch.bool,
  'u8': torch.uint8, 'u16': 'invalid', 'u32': 'invalid', 'u64': 'invalid',
  'i8': torch.int8, 'i16': torch.int16, 'i32': torch.int32, 'i64': torch.int64,
  'bf16': torch.bfloat16, 'f16': torch.float16, 'f32': torch.float32, 'f64': torch.float64,
  'c64': torch.complex64, 'c128': torch.complex128,
  'i*': int, 'f*': float, 'c*': complex}

torch_dtype_to_code = {val: key for key, val in torch_dtypes.items()}

def make_torch_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return torch.zeros(1, dtype=dtype)

def torch_result_code(dtype1, dtype2):
  try:
    out = torch.add(make_torch_zero(dtype1), make_torch_zero(dtype2))
  except TypeError:
    return '-'
  else:
    if type(out) in {int, float, complex}:
      return torch_dtype_to_code[type(out)]
    else:
      return torch_dtype_to_code[out.dtype]


grid = [[torch_result_code(dtype1, dtype2)
         for dtype2 in torch_dtypes.values()]
        for dtype1 in torch_dtypes.values()]
table = pd.DataFrame(grid, index=torch_dtypes.keys(), columns=torch_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b b u8 - - - i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i64 f32 c64
u8 u8 u8 - - - i16 i16 i32 i64 bf16 f16 f32 f64 c64 c128 u8 f32 c64
u16 - - - - - - - - - - - - - - - - - -
u32 - - - - - - - - - - - - - - - - - -
u64 - - - - - - - - - - - - - - - - - -
i8 i8 i16 - - - i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i8 f32 c64
i16 i16 i16 - - - i16 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i16 f32 c64
i32 i32 i32 - - - i32 i32 i32 i64 bf16 f16 f32 f64 c64 c128 i32 f32 c64
i64 i64 i64 - - - i64 i64 i64 i64 bf16 f16 f32 f64 c64 c128 i64 f32 c64
bf16 bf16 bf16 - - - bf16 bf16 bf16 bf16 bf16 f32 f32 f64 c64 c128 bf16 bf16 c64
f16 f16 f16 - - - f16 f16 f16 f16 f32 f16 f32 f64 c64 c128 f16 f16 c64
f32 f32 f32 - - - f32 f32 f32 f32 f32 f32 f32 f64 c64 c128 f32 f32 c64
f64 f64 f64 - - - f64 f64 f64 f64 f64 f64 f64 f64 c128 c128 f64 f64 c128
c64 c64 c64 - - - c64 c64 c64 c64 c64 c64 c64 c128 c64 c128 c64 c64 c64
c128 c128 c128 - - - c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128
i* i64 u8 - - - i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i64 f32 c64
f* f32 f32 - - - f32 f32 f32 f32 bf16 f16 f32 f64 c64 c128 f32 f64 c64
c* c64 c64 - - - c64 c64 c64 c64 c64 c64 c64 c128 c64 c128 c64 c64 c128

JAX 类型提升:jax.numpy#

jax.numpy 遵循 https://jax.ac.cn/en/latest/type_promotion.html 中规定的类型提升规则。在这里,我们使用 i*f*c* 来表示 Python 标量和弱类型数组。

隐藏代码单元格来源
# @title
import jax
import jax.numpy as jnp
import pandas as pd
from IPython import display
jax.config.update('jax_enable_x64', True)

jnp_dtypes = {
  'b': jnp.bool_.dtype,
  'u8': jnp.uint8.dtype, 'u16': jnp.uint16.dtype, 'u32': jnp.uint32.dtype, 'u64': jnp.uint64.dtype,
  'i8': jnp.int8.dtype, 'i16': jnp.int16.dtype, 'i32': jnp.int32.dtype, 'i64': jnp.int64.dtype,
  'bf16': jnp.bfloat16.dtype, 'f16': jnp.float16.dtype, 'f32': jnp.float32.dtype, 'f64': jnp.float64.dtype,
  'c64': jnp.complex64.dtype, 'c128': jnp.complex128.dtype,
  'i*': int, 'f*': float, 'c*': complex}


jnp_dtype_to_code = {val: key for key, val in jnp_dtypes.items()}

def make_jnp_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return jnp.zeros((), dtype=dtype)

def jnp_result_code(dtype1, dtype2):
  try:
    out = jnp.add(make_jnp_zero(dtype1), make_jnp_zero(dtype2))
  except TypeError:
    return '-'
  else:
    if hasattr(out, 'aval') and out.aval.weak_type:
      return out.dtype.kind + '*'
    elif type(out) in {int, float, complex}:
      return jnp_dtype_to_code[type(out)]
    else:
      return jnp_dtype_to_code[out.dtype]

grid = [[jnp_result_code(dtype1, dtype2)
         for dtype2 in jnp_dtypes.values()]
        for dtype1 in jnp_dtypes.values()]
table = pd.DataFrame(grid, index=jnp_dtypes.keys(), columns=jnp_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
u8 u8 u8 u16 u32 u64 i16 i16 i32 i64 bf16 f16 f32 f64 c64 c128 u8 f* c*
u16 u16 u16 u16 u32 u64 i32 i32 i32 i64 bf16 f16 f32 f64 c64 c128 u16 f* c*
u32 u32 u32 u32 u32 u64 i64 i64 i64 i64 bf16 f16 f32 f64 c64 c128 u32 f* c*
u64 u64 u64 u64 u64 u64 f* f* f* f* bf16 f16 f32 f64 c64 c128 u64 f* c*
i8 i8 i16 i32 i64 f* i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i8 f* c*
i16 i16 i16 i32 i64 f* i16 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i16 f* c*
i32 i32 i32 i32 i64 f* i32 i32 i32 i64 bf16 f16 f32 f64 c64 c128 i32 f* c*
i64 i64 i64 i64 i64 f* i64 i64 i64 i64 bf16 f16 f32 f64 c64 c128 i64 f* c*
bf16 bf16 bf16 bf16 bf16 bf16 bf16 bf16 bf16 bf16 bf16 f32 f32 f64 c64 c128 bf16 bf16 c64
f16 f16 f16 f16 f16 f16 f16 f16 f16 f16 f32 f16 f32 f64 c64 c128 f16 f16 c64
f32 f32 f32 f32 f32 f32 f32 f32 f32 f32 f32 f32 f32 f64 c64 c128 f32 f32 c64
f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 c128 c128 f64 f64 c128
c64 c64 c64 c64 c64 c64 c64 c64 c64 c64 c64 c64 c64 c128 c64 c128 c64 c64 c64
c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128
i* i* u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
f* f* f* f* f* f* f* f* f* f* bf16 f16 f32 f64 c64 c128 f* f* c*
c* c* c* c* c* c* c* c* c* c* c64 c64 c64 c128 c64 c128 c* c* c*

JAX 类型提升:jax.lax#

jax.lax 是一个更底层的接口,不会进行任何隐式类型提升。这里我们使用 i*f*c* 来表示 Python 标量和弱类型数组。

隐藏代码单元格来源
# @title
import jax
import jax.numpy as jnp
import pandas as pd
from IPython import display
jax.config.update('jax_enable_x64', True)

jnp_dtypes = {
  'b': jnp.bool_.dtype,
  'u8': jnp.uint8.dtype, 'u16': jnp.uint16.dtype, 'u32': jnp.uint32.dtype, 'u64': jnp.uint64.dtype,
  'i8': jnp.int8.dtype, 'i16': jnp.int16.dtype, 'i32': jnp.int32.dtype, 'i64': jnp.int64.dtype,
  'bf16': jnp.bfloat16.dtype, 'f16': jnp.float16.dtype, 'f32': jnp.float32.dtype, 'f64': jnp.float64.dtype,
  'c64': jnp.complex64.dtype, 'c128': jnp.complex128.dtype,
  'i*': int, 'f*': float, 'c*': complex}


jnp_dtype_to_code = {val: key for key, val in jnp_dtypes.items()}

def make_jnp_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return jnp.zeros((), dtype=dtype)

def jnp_result_code(dtype1, dtype2):
  try:
    out = jax.lax.add(make_jnp_zero(dtype1), make_jnp_zero(dtype2))
  except TypeError:
    return '-'
  else:
    if hasattr(out, 'aval') and out.aval.weak_type:
      return out.dtype.kind + '*'
    elif type(out) in {int, float, complex}:
      return jnp_dtype_to_code[type(out)]
    else:
      return jnp_dtype_to_code[out.dtype]

grid = [[jnp_result_code(dtype1, dtype2)
         for dtype2 in jnp_dtypes.values()]
        for dtype1 in jnp_dtypes.values()]
table = pd.DataFrame(grid, index=jnp_dtypes.keys(), columns=jnp_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b - - - - - - - - - - - - - - - - - -
u8 - u8 - - - - - - - - - - - - - - - -
u16 - - u16 - - - - - - - - - - - - - - -
u32 - - - u32 - - - - - - - - - - - - - -
u64 - - - - u64 - - - - - - - - - - - - -
i8 - - - - - i8 - - - - - - - - - - - -
i16 - - - - - - i16 - - - - - - - - - - -
i32 - - - - - - - i32 - - - - - - - - - -
i64 - - - - - - - - i64 - - - - - - i64 - -
bf16 - - - - - - - - - bf16 - - - - - - - -
f16 - - - - - - - - - - f16 - - - - - - -
f32 - - - - - - - - - - - f32 - - - - - -
f64 - - - - - - - - - - - - f64 - - - f64 -
c64 - - - - - - - - - - - - - c64 - - - -
c128 - - - - - - - - - - - - - - c128 - - c128
i* - - - - - - - - i64 - - - - - - i* - -
f* - - - - - - - - - - - - f64 - - - f* -
c* - - - - - - - - - - - - - - c128 - - c*