JAX 类型提升语义的设计#

Open in Colab Open in Kaggle

Jake VanderPlas,2021 年 12 月

在任何数值计算库的设计中,都面临着一个挑战,即如何处理不同类型的值之间的运算。本文档概述了 JAX 使用的提升语义背后的思考过程,总结在 JAX 类型提升语义 中。

JAX 类型提升的目标#

JAX 的数值计算 API 借鉴了 NumPy 的 API,并进行了一些增强,包括能够针对 GPU 和 TPU 等加速器。这使得采用 NumPy 的类型提升系统对于 JAX 用户来说是不利的:NumPy 的类型提升规则过分偏向于 64 位输出,这对于加速器上的计算来说是有问题的。诸如 GPU 和 TPU 等设备通常会因为使用 64 位浮点数类型而导致明显的性能下降,并且在某些情况下根本不支持原生 64 位浮点数类型。

这种有问题的类型提升语义的一个简单示例可以在 32 位整数和浮点数之间的二元运算中看到。

import numpy as np
np.dtype(np.int32(1) + np.float32(1))
dtype('float64')

NumPy 倾向于生成 64 位值的特性是使用 NumPy API 进行加速器计算的一个长期存在的问题,目前还没有很好的解决方案。出于这个原因,JAX 试图重新思考 NumPy 风格的类型提升,并考虑到加速器。

回顾:表格和格#

在我们深入细节之前,让我们花一点时间退一步,思考一下如何思考类型提升的问题。考虑 Python 中内置数值类型之间的算术运算,即类型为intfloatcomplex的类型。只需几行代码,我们就可以生成 Python 用于这些类型值之间加法的类型提升表。

import pandas as pd
types = [int, float, complex]
name = lambda t: t.__name__
pd.DataFrame([[name(type(t1(1) + t2(1))) for t1 in types] for t2 in types],
             index=[name(t) for t in types], columns=[name(t) for t in types])
int float complex
int int float complex
float float float complex
complex complex complex complex

此表列出了 Python 的数值类型提升行为,但事实证明,存在一种更紧凑的补充表示法:表示法,其中任意两个节点之间的上确界是它们提升到的类型。Python 提升表的格表示法要简单得多。

隐藏代码单元格源代码
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {'int': ['float'], 'float': ['complex']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {'int': [0, 0], 'float': [1, 0], 'complex': [2, 0]}
fig, ax = plt.subplots(figsize=(8, 2))
nx.draw(graph, with_labels=True, node_size=4000, node_color='lightgray', pos=pos, ax=ax, arrowsize=20)
../_images/818a3cf499d15c3be1d4c116db142da0418c174873f21e1ffcde679c6058f918.png

此格是对上述提升表中信息的紧凑编码。您可以通过跟踪图到两个节点的第一个公共子节点(包括节点本身)来找到两个输入的类型提升结果;在数学上,这个公共子节点被称为格上的上确界最小上界并集;在这里,我们将此操作称为并集

从概念上讲,箭头表示允许在源和目标之间进行隐式类型提升:例如,允许从整数到浮点数的隐式提升,但不允许从浮点数到整数的隐式提升。

请记住,通常并非每个有向无环图 (DAG) 都能满足格的属性。格要求每个节点对都存在唯一的最小上界;因此,例如,以下两个 DAG 不是格。

隐藏代码单元格源代码
#@title
import networkx as nx
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 2, figsize=(10, 2))

lattice = {'A': ['B', 'C']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {'A': [0, 0], 'B': [1, 0.5], 'C': [1, -0.5]}
nx.draw(graph, with_labels=True, node_size=2000, node_color='lightgray', pos=pos, ax=ax[0], arrowsize=20)
ax[0].set(xlim=[-0.5, 1.5], ylim=[-1, 1])

lattice = {'A': ['C', 'D'], 'B': ['C', 'D']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {'A': [0, 0.5], 'B': [0, -0.5], 'C': [1, 0.5], 'D': [1, -0.5]}
nx.draw(graph, with_labels=True, node_size=2000, node_color='lightgray', pos=pos, ax=ax[1], arrowsize=20)
ax[1].set(xlim=[-0.5, 1.5], ylim=[-1, 1]);
../_images/a0acbd07f9486d95c10a36c11301d528fb7e65d671d622226151c431b3e36c62.png

左侧 DAG 不是格,因为节点BC不存在上界;右侧 DAG 出于两个原因失败:首先,节点CD不存在上界,并且对于节点AB,最小上界不能唯一确定:CD都是候选者,但它们是不可比较的。

类型提升格的属性#

用格来指定类型提升确保了许多有用的属性。用\(\vee\)运算符表示格上的并集,我们有

存在性:根据定义,格要求每对元素都存在唯一的格并集:\(\forall (a, b): \exists !(a \vee b)\)

交换律:格并集是可交换的:\(\forall (a, b): a\vee b = b \vee a\)

结合律:格并集是结合的:\(\forall (a, b, c): a \vee (b \vee c) = (a \vee b) \vee c\)

另一方面,这些属性意味着对它们可以表示的类型提升系统有限制;特别是并非所有类型提升表都可以用格表示。NumPy 的完整类型提升表就是一个现成的例子;这可以通过反例快速证明:这里有三种标量类型,它们在 NumPy 中的提升行为是非结合的。

import numpy as np
a, b, c = np.int8(1), np.uint8(1), np.float16(1)
print(np.dtype((a + b) + c))
print(np.dtype(a + (b + c)))
float32
float16

这样的结果可能会让用户感到意外:我们通常期望数学表达式映射到数学概念,因此,例如,a + b + c应该等价于c + b + ax * (y + z)应该等价于x * y + x * z。如果类型提升是非结合的或非交换的,则这些属性将不再适用。

此外,与基于表格的系统相比,基于格的类型提升系统更容易概念化和理解。例如,JAX 识别 18 种不同的类型:由 18 个节点和它们之间稀疏且动机良好的连接组成的提升格比包含 324 个条目的表格更容易记住。

出于这个原因,我们选择对 JAX 使用基于格的类型提升系统。

类别内的类型提升#

数值计算库通常提供的不仅仅是intfloatcomplex;在每个类别中,都有各种可能的精度,由数值表示中使用的位数表示。我们将在此处考虑的类别为

  • 无符号整数,包括uint8uint16uint32uint64(我们将简称为u8u16u32u64

  • 有符号整数,包括int8int16int32int64(我们将简称为i8i16i32i64

  • 浮点数,包括float16float32float64(我们将简称为f16f32f64

  • 复数浮点数,包括complex64complex128(我们将简称为c64c128

NumPy 的类型提升语义**在**这四个类别中的每个类别中都相对简单:类型的有序层次结构直接转换为四个单独的格,表示类别内类型提升规则。

隐藏代码单元格源代码
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'u8': [0, 0], 'u16': [1, 0], 'u32': [2, 0], 'u64': [3, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [1, 2], 'f32': [2, 2], 'f64': [3, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 4))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/2d8495bcb006c34b42eeb4f3e0c6530fdef0bd7364c56184993925f0cf157abc.png

就 JAX 试图避免的提升到 64 位的值而言,这些在每个类型类别内的同类提升语义是没有问题的:产生 64 位输出的唯一方法是具有 64 位输入。

输入 Python 标量#

现在让我们考虑一下 Python 标量如何融入其中。

在 NumPy 中,提升行为取决于输入是数组还是标量。例如,当对两个标量进行运算时,将应用正常的提升规则。

x = np.int8(0)  # int8 scalar
y = 1  # Python int = int64 scalar
(x + y).dtype
dtype('int64')

这里,Python 值1被视为int64,并且简单的类别内规则导致int64结果。

但是,在 Python 标量和 NumPy 数组之间的运算中,标量会推迟到数组的 dtype。例如

x = np.zeros(1, dtype='int8')  # int8 array
y = 1  # Python int = int64 scalar
(x + y).dtype
dtype('int8')

这里,int64标量的位宽被忽略,推迟到数组的位宽。

这里还有一个细节:当 NumPy 类型提升涉及标量时,输出 dtype 取决于值:如果 Python 标量对于给定的 dtype 太大,则会将其提升到兼容的类型。

x = np.zeros(1, dtype='int8')  # int8 array
y = 1000  # int64 scalar
(x + y).dtype
dtype('int16')

出于 JAX 的目的,依赖于值的提升是一个无法接受的方案,因为 JIT 编译和其他转换的本质是在没有引用其值的情况下对数据的抽象表示进行操作。

忽略依赖于值的效应,NumPy 类型提升的有符号整数分支可以用以下格表示,其中我们将使用*标记标量 dtype。

隐藏代码单元格源代码
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i8*': ['i16*'], 'i16*': ['i32*'], 'i32*': ['i64*'], 'i64*': ['i8'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i8*': [0, 1], 'i16*': [2, 1], 'i32*': [4, 1], 'i64*': [6, 1],
  'i8': [9, 1], 'i16': [11, 1], 'i32': [13, 1], 'i64': [15, 1],
}
fig, ax = plt.subplots(figsize=(12, 4))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
ax.text(3, 1.6, "Scalar Types", ha='center', fontsize=14)
ax.text(12, 1.6, "Array Types", ha='center', fontsize=14)
ax.set_ylim(-1, 3);
../_images/7e8c3295e403209560d8e142c5c830d79456a4e6d207dd1a7e4d15b55c56006b.png

uintfloatcomplex格中也存在类似的模式。

为了简单起见,让我们将每类标量类型折叠成一个节点,分别表示为u*i*f*c*。我们的一组类别内格现在可以这样表示。

隐藏代码单元格源代码
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'u*': ['u8'], 'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i*': ['i8'], 'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f*': ['f16'], 'f16': ['f32'], 'f32': ['f64'],
  'c*': ['c64'], 'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'u*': [0, 0], 'u8': [3, 0], 'u16': [5, 0], 'u32': [7, 0], 'u64': [9, 0],
  'i*': [0, 1], 'i8': [3, 1], 'i16': [5, 1], 'i32': [7, 1], 'i64': [9, 1],
  'f*': [0, 2], 'f16': [5, 2], 'f32': [7, 2], 'f64': [9, 2],
  'c*': [0, 3], 'c64': [7, 3], 'c128': [9, 3],
}
fig, ax = plt.subplots(figsize=(6, 4))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/0fbe0c20cd350821e64f3742aa7864ec729565572b136950042095881672fdb9.png

在某种意义上,将标量放在左侧是一个奇怪的选择:标量类型可能包含任何宽度的值,但在与给定类型的数组交互时,提升结果会推迟到数组类型。这样做的好处是,当您执行诸如x + 2(对于数组x)之类的操作时,无论x的宽度如何,x的类型都会传递到结果中。

for dtype in [np.int8, np.int16, np.int32, np.int64]:
  x = np.arange(10, dtype=dtype)
  assert (x + 2).dtype == dtype

此行为为我们用于标量值的*表示法提供了动机:*让人联想到可以取任何所需值的通配符。

这些语义的好处是,您可以使用简洁的 Python 代码轻松地表达操作序列,而无需显式地将标量转换为适当的类型。想象一下,如果您不必编写以下内容:

3 * (x + 1) ** 2

您必须编写以下内容:

np.int32(3) * (x + np.int32(1)) ** np.int32(2)

虽然它是显式的,但数值代码将变得难以阅读或编写。使用上面描述的标量提升语义,给定类型为int32的数组x,第二个语句中的类型在第一个语句中是隐式的。

组合格#

回顾一下,我们开始讨论时介绍了表示 Python 中类型提升的格:int -> float -> complex。让我们将其改写为i* -> f* -> c*,并且让我们进一步允许i*包含u*(毕竟,Python 中没有无符号整数标量类型)。

将这些都放在一起,我们得到以下表示 Python 标量和 NumPy 数组之间类型提升的部分格

隐藏代码单元格源代码
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/796586be87180b0de3171d39763f2d33a80a641b72d82c00f0c0e352f754f201.png

请注意,这(目前)不是一个真正的格:许多节点对不存在连接。但是,我们可以将其视为一个部分格,其中某些节点对没有定义的提升行为,并且这个部分格的已定义部分确实正确地描述了 NumPy 的数组提升行为(忽略上面提到的依赖于值的语义)。

这为我们提供了一个很好的框架,通过它我们可以考虑通过在这个图上添加连接来填充这些未定义的提升规则。但是要添加哪些连接呢?广义地说,我们希望任何额外的连接都满足一些属性

  1. 提升应满足交换律和结合律:换句话说,该图应保持为(部分)格。

  2. 提升绝不允许丢弃数据的整个组件:例如,我们绝不应将complex提升为float,因为它会丢弃任何虚部。

  3. 提升绝不应导致未处理的溢出。例如,最大的uint32是最大的int32的两倍,因此我们不应隐式地将uint32提升为int32

  4. 在可能的情况下,提升应避免精度损失。例如,一个int64值可能具有 64 位尾数,因此将int64提升为float64表示可能存在精度损失。但是,最大的可表示 float64 大于最大的可表示 int64,因此在这种情况下标准 #3 仍然得到满足。

  5. 在可能的情况下,二元提升应避免导致比输入更宽的类型。这是为了确保 JAX 的隐式提升对基于加速器的流程友好,在这些流程中,用户通常希望将类型限制为 32 位(或在某些情况下为 16 位)值。

格上的每个新连接都为用户带来一定程度的便利(一组无需显式转换即可交互的新类型),但如果违反了上述任何标准,则便利性可能会变得过于昂贵。开发一个完整的提升格涉及在便利性和成本之间取得平衡。

混合提升:浮点数和复数#

让我们从也许最简单的情况开始,即浮点数和复数值之间的提升。

复数由一对浮点数组成,因此我们之间存在自然的提升路径:在保持实部宽度的情况下将浮点数转换为复数。就我们的部分格表示而言,它看起来像这样

隐藏代码单元格源代码
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/bf87909b2344aed80590d1c6d91585a02b25898ac217526cb49948d91205318f.png

事实证明,这恰好表示了 NumPy 在混合浮点数/复数类型提升中使用的语义。

混合提升:有符号和无符号整数#

对于下一个案例,让我们考虑一些更困难的事情:有符号整数和无符号整数之间的提升。例如,在将uint8提升为有符号整数时,我们需要多少位?

乍一看,您可能会认为将uint8提升为int8很自然;但是最大的uint8数字在int8中不可表示。出于这个原因,将无符号整数提升为位数加倍的整数更有意义;这种提升行为可以通过向提升格添加以下连接来表示

隐藏代码单元格源代码
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/3be7e17889458ac823bb5dacf31525c0d96578c6854962f45dcc60ec987a30bd.png

同样,此处添加的连接正是 NumPy 用于混合整数提升的提升语义。

如何处理uint64#

混合有符号/无符号整数提升的方法遗漏了一种类型:uint64。按照上述模式,涉及uint64的混合整数运算的结果应为int128,但这不是标准的可用 dtype。

NumPy 在此处的选择是提升为float64

(np.uint64(1) + np.int64(1)).dtype
dtype('float64')

但是,这可能是一种令人惊讶的约定:它是整数类型提升不产生整数的唯一情况。目前,我们将uint64提升保持为未定义,稍后我们将返回到它。

混合提升:整数和浮点数#

在将整数提升为浮点数时,我们可能会从与有符号整数和无符号整数之间混合提升相同的思考过程开始。一个 16 位有符号或无符号整数不能由一个 16 位浮点数以完全精度表示,后者只有 10 位尾数。因此,将整数提升为由两倍位数表示的浮点数可能是有意义的

隐藏代码单元格源代码
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16', 'f16'], 'u16': ['u32', 'i32', 'f32'], 'u32': ['u64', 'i64', 'f64'],
  'i8': ['i16', 'f16'], 'i16': ['i32', 'f32'], 'i32': ['i64', 'f64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/8b3247e8189fbfad46a7e5583b636866fc45576e07c9bfd904457926306299d1.png

这实际上是 NumPy 类型提升所做的,但这样做破坏了图的格属性:例如,对{i8, u8}不再具有唯一的最小上界:可能性是i16f16,它们在图上是不可排序的。事实证明,这是上面突出显示的 NumPy 非结合性类型提升的根源。

我们能否想出 NumPy 提升规则的修改,使其满足格属性,同时也为混合类型提升提供合理的 结果?我们可以在此处采取一些方法。

选项 0:将整数/浮点数混合精度保留为未定义#

为了使行为完全可预测(以牺牲用户便利性为代价),一个合理的做法是将超出 Python 标量的任何混合整数/浮点数提升保留为未定义,停止使用上一节中的部分格。缺点是需要用户在整数和浮点数之间进行操作时进行显式类型转换。

选项 1:避免所有精度损失#

如果我们的重点是避免任何代价的精度损失,我们可以通过通过其现有有符号整数路径将无符号整数提升为浮点数来恢复格属性

隐藏代码单元格源代码
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16', 'f16'], 'i16': ['i32', 'f32'], 'i32': ['i64', 'f64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/1eda89d008a8c6dadf926229bf9f2245722006c5bc1c42961c555a2595c95117.png

这种方法的缺点是它仍然使int64uint64提升保持未定义,因为没有标准浮点数类型具有足够的尾数位来表示其全部值的范围。我们可以放宽精度约束并通过从i64->f64u64->f64绘制连接来完成格,但是这些链接将违背此提升方案的动机。

第二个缺点是这个格使得很难找到一个合理的位置来插入bfloat16(见下文)同时保持格属性。

这种方法的第三个缺点,对于 JAX 的加速器后端来说更为重要,是某些操作会导致比必要类型宽得多的类型;例如uint16float16之间的混合操作将一直提升到float64,这不是理想的。

选项 2:避免大多数比必要更宽的提升#

为了解决对更宽类型的无必要提升,我们可以接受在整数/浮点数提升中存在一些精度损失的可能性,将有符号整数提升为相同宽度的浮点数

隐藏代码单元格源代码
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16'], 'i16': ['f16', 'i32'], 'i32': ['f32', 'i64'], 'i64': ['f64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [1.5, 2], 'f32': [2.5, 2], 'f64': [3.5, 2],
  'c64': [3, 3], 'c128': [4, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/f41cee38a476bf636be901e7f64a5dc3687002f9d12532ab706b9077d602b175.png

虽然这确实允许在整数和浮点数之间进行精度损失的提升,但这些提升不会错误地表示结果的幅度:尽管浮点数尾数不够宽以表示所有值,但指数足够宽以近似它们。

这种方法还允许从int64float64的自然提升路径,尽管uint64在此方案中仍然不可提升。也就是说,从u64f64的连接可以比以前更容易地在这里得到证明。

这种提升方案仍然导致一些比必要的提升路径更宽;例如float32uint32之间的操作导致float64。此外,这个格使得很难找到一个合理的位置来插入bfloat16(见下文)同时保持格属性。

选项 3:避免所有比必要更宽的提升#

如果我们愿意从根本上改变我们围绕整数和浮点数提升的思维方式,我们可以避免所有非理想的 64 位提升。就像标量始终遵循数组类型的宽度一样,我们可以使整数始终遵循浮点类型的宽度

隐藏代码单元格源代码
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'], 'i64': ['f*'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [1.5, 2], 'f32': [2.5, 2], 'f64': [3.5, 2],
  'c64': [3, 3], 'c128': [4, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/d3f5e5be4354238a60698cb4f228d4e1f75a665577343c36b2c1ade1207783a0.png

这涉及一个小小的障眼法:之前我们使用f*来指代标量类型。在这个格中,f*可能应用于混合计算的数组输出。与其将f*视为标量,不如将其视为一种具有不同提升规则的特殊float值:在 JAX 中,我们将其称为弱浮点数;见下文。

这种方法的优点是,在无符号整数之外,它避免了所有比必要的提升更宽:您永远不会在没有 64 位输入的情况下获得 f64 输出,并且您永远不会在没有 32 位输入的情况下获得 f32 输出:这导致在加速器上工作的便利语义,同时避免了意外的 64 位值。

赋予浮点数类型优先级的特性类似于 PyTorch 的类型提升行为。这个格也恰好生成了一个提升表,该表非常类似于 JAX 最初的特定于情况的类型提升方案,该方案不是基于格的,但具有赋予浮点数类型优先级的特性。

此外,这个格提供了一个自然的位置来插入bfloat16,而无需在bf16f16之间强加排序。

隐藏代码单元格源代码
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['u8', 'i8'], 'f*': ['c*', 'f16', 'bf16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'], 'i64': ['f*'],
  'f16': ['f32'], 'bf16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [1.8, 1.7], 'bf16': [1.8, 2.3], 'f32': [3.0, 2], 'f64': [4.0, 2],
  'c64': [3.5, 3], 'c128': [4.5, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
../_images/aa73688b580b02776fce218d6efe58792ae3b0976160a4b0c130b797780578af.png

这一点很重要,因为f16bf16不可比较,因为它们以不同的方式利用其位:bf16以较低的精度表示更大的范围,而f16以较高的精度表示较小的范围。

然而,这些优势也带来了一些权衡。

  • 混合浮点数/整数提升非常容易导致精度损失:例如,int64(最大值为\(9.2 \times 10^{18}\))可以提升为float16(最大值为\(6.5 \times 10^4\)),这意味着大多数可表示的值将变为inf

  • 如上所述,f*不再可以被认为是“标量类型”,而是一种不同类型的 float64。用 JAX 的术语来说,这被称为弱类型,因为它表示为 64 位,但在与其他值的提升中仅弱持有此位宽。

请注意,这种方法仍然没有解决uint64的提升问题,尽管将格闭合并连接u64f*可能是合理的。

JAX 中的类型提升#

在设计 JAX 的类型提升语义时,我们牢记了其中许多想法,并重点依赖了几件事。

  1. 我们选择将 JAX 的类型提升语义约束为满足格属性的图:这是为了确保结合律和交换律,同时也为了允许语义以 DAG 的形式进行紧凑描述,而不是需要一个大型表格。

  2. 我们倾向于避免无意中提升到更宽类型的语义,特别是在涉及浮点值时,以便使加速器上的计算受益。

  3. 如果需要维护 (1) 和 (2),我们不介意在混合类型提升中接受潜在的精度损失(但不是幅度损失)。

考虑到这一点,JAX 采用了选项 3。或者更确切地说,是选项 3 的一个略微修改的版本,它在u64f*之间建立了连接,以便创建一个真正的格。为了清晰起见,重新排列节点后,JAX 的类型提升格如下所示。

隐藏代码单元格源代码
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['u8', 'i8'], 'f*': ['c*', 'f16', 'bf16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'], 'u64': ['f*'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'], 'i64': ['f*'],
  'f16': ['f32'], 'bf16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [4.5, 0.5], 'c*': [5, 1.5],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [5.75, 0.8], 'bf16': [5.75, 0.2], 'f32': [7, 0.5], 'f64': [8, 0.5],
  'c64': [7.5, 1.5], 'c128': [8.5, 1.5],
}
fig, ax = plt.subplots(figsize=(10, 4))
ax.set_ylim(-0.5, 2)
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
# ax.patches[12].set_linestyle((0, (2, 4)))
../_images/d261add493a579484d9772634ce146f1240af3966d0845839c354417a3de2e53.png

此选择产生的行为总结在JAX 类型提升语义中。值得注意的是,除了包含更大的无符号类型(u16u32u64)以及关于标量/弱类型(i*f*c*)行为的一些细节之外,这种类型提升方案与 PyTorch 选择的方案非常接近。

对于感兴趣的人,以下附录打印了 NumPy、Tensorflow、PyTorch 和 JAX 使用的完整提升表。

附录:类型提升表示例#

以下是各种 Python 数组计算库实现的隐式类型提升表的一些示例。

NumPy 类型提升#

请注意,NumPy 不包含bfloat16数据类型,并且下表忽略了依赖于值的效应。

隐藏代码单元格源代码
# @title

import numpy as np
import pandas as pd
from IPython import display

np_dtypes = {
  'b': np.bool_,
  'u8': np.uint8, 'u16': np.uint16, 'u32': np.uint32, 'u64': np.uint64,
  'i8': np.int8, 'i16': np.int16, 'i32': np.int32, 'i64': np.int64,
  'bf16': 'invalid', 'f16': np.float16, 'f32': np.float32, 'f64': np.float64,
  'c64': np.complex64, 'c128': np.complex128,
  'i*': int, 'f*': float, 'c*': complex}

np_dtype_to_code = {val: key for key, val in np_dtypes.items()}

def make_np_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return np.zeros(1, dtype=dtype)

def np_result_code(dtype1, dtype2):
  try:
    out = np.add(make_np_zero(dtype1), make_np_zero(dtype2))
  except TypeError:
    return '-'
  else:
    if type(out) in {int, float, complex}:
      return np_dtype_to_code[type(out)]
    else:
      return np_dtype_to_code[out.dtype.type]


grid = [[np_result_code(dtype1, dtype2)
         for dtype2 in np_dtypes.values()]
        for dtype1 in np_dtypes.values()]
table = pd.DataFrame(grid, index=np_dtypes.keys(), columns=np_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b b u8 u16 u32 u64 i8 i16 i32 i64 - f16 f32 f64 c64 c128 i64 f64 c128
u8 u8 u8 u16 u32 u64 i16 i16 i32 i64 - f16 f32 f64 c64 c128 u8 f64 c128
u16 u16 u16 u16 u32 u64 i32 i32 i32 i64 - f32 f32 f64 c64 c128 u16 f64 c128
u32 u32 u32 u32 u32 u64 i64 i64 i64 i64 - f64 f64 f64 c128 c128 u32 f64 c128
u64 u64 u64 u64 u64 u64 f64 f64 f64 f64 - f64 f64 f64 c128 c128 u64 f64 c128
i8 i8 i16 i32 i64 f64 i8 i16 i32 i64 - f16 f32 f64 c64 c128 i8 f64 c128
i16 i16 i16 i32 i64 f64 i16 i16 i32 i64 - f32 f32 f64 c64 c128 i16 f64 c128
i32 i32 i32 i32 i64 f64 i32 i32 i32 i64 - f64 f64 f64 c128 c128 i32 f64 c128
i64 i64 i64 i64 i64 f64 i64 i64 i64 i64 - f64 f64 f64 c128 c128 i64 f64 c128
bf16 - - - - - - - - - - - - - - - - - -
f16 f16 f16 f32 f64 f64 f16 f32 f64 f64 - f16 f32 f64 c64 c128 f16 f16 c64
f32 f32 f32 f32 f64 f64 f32 f32 f64 f64 - f32 f32 f64 c64 c128 f32 f32 c64
f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 - f64 f64 f64 c128 c128 f64 f64 c128
c64 c64 c64 c64 c128 c128 c64 c64 c128 c128 - c64 c64 c128 c64 c128 c64 c64 c64
c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 - c128 c128 c128 c128 c128 c128 c128 c128
i* i64 u8 u16 u32 u64 i8 i16 i32 i64 - f16 f32 f64 c64 c128 i64 f64 c128
f* f64 f64 f64 f64 f64 f64 f64 f64 f64 - f16 f32 f64 c64 c128 f64 f64 c128
c* c128 c128 c128 c128 c128 c128 c128 c128 c128 - c64 c64 c128 c64 c128 c128 c128 c128

Tensorflow 类型提升#

Tensorflow 避免定义隐式类型提升,除了在有限情况下使用 Python 标量。该表是不对称的,因为在tf.add(x, y)中,y的类型必须可以强制转换为x的类型。

隐藏代码单元格源代码
# @title

import tensorflow as tf
import pandas as pd
from IPython import display

tf_dtypes = {
  'b': tf.bool,
  'u8': tf.uint8, 'u16': tf.uint16, 'u32': tf.uint32, 'u64': tf.uint64,
  'i8': tf.int8, 'i16': tf.int16, 'i32': tf.int32, 'i64': tf.int64,
  'bf16': tf.bfloat16, 'f16': tf.float16, 'f32': tf.float32, 'f64': tf.float64,
  'c64': tf.complex64, 'c128': tf.complex128,
  'i*': int, 'f*': float, 'c*': complex}

tf_dtype_to_code = {val: key for key, val in tf_dtypes.items()}

def make_tf_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return tf.zeros(1, dtype=dtype)

def result_code(dtype1, dtype2):
  try:
    out = tf.add(make_tf_zero(dtype1), make_tf_zero(dtype2))
  except (TypeError, tf.errors.InvalidArgumentError):
    return '-'
  else:
    if type(out) in {int, float, complex}:
      return tf_dtype_to_code[type(out)]
    else:
      return tf_dtype_to_code[out.dtype]


grid = [[result_code(dtype1, dtype2)
         for dtype2 in tf_dtypes.values()]
        for dtype1 in tf_dtypes.values()]
table = pd.DataFrame(grid, index=tf_dtypes.keys(), columns=tf_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b - - - - - - - - - - - - - - - - - -
u8 - u8 - - - - - - - - - - - - - u8 - -
u16 - - u16 - - - - - - - - - - - - u16 - -
u32 - - - u32 - - - - - - - - - - - u32 - -
u64 - - - - u64 - - - - - - - - - - u64 - -
i8 - - - - - i8 - - - - - - - - - i8 - -
i16 - - - - - - i16 - - - - - - - - i16 - -
i32 - - - - - - - i32 - - - - - - - i32 - -
i64 - - - - - - - - i64 - - - - - - i64 - -
bf16 - - - - - - - - - bf16 - - - - - bf16 bf16 -
f16 - - - - - - - - - - f16 - - - - f16 f16 -
f32 - - - - - - - - - - - f32 - - - f32 f32 -
f64 - - - - - - - - - - - - f64 - - f64 f64 -
c64 - - - - - - - - - - - - - c64 - c64 c64 c64
c128 - - - - - - - - - - - - - - c128 c128 c128 c128
i* - - - - - - - i32 - - - - - - - i32 - -
f* - - - - - - - - - - - f32 - - - f32 f32 -
c* - - - - - - - - - - - - - - c128 c128 c128 c128

PyTorch 类型提升#

请注意,torch 不包含大于uint8的无符号整数类型。除了这一点以及关于标量/弱类型提升的一些细节之外,该表与jax.numpy使用的表非常接近。

隐藏代码单元格源代码
# @title
import torch
import pandas as pd
from IPython import display

torch_dtypes = {
  'b': torch.bool,
  'u8': torch.uint8, 'u16': 'invalid', 'u32': 'invalid', 'u64': 'invalid',
  'i8': torch.int8, 'i16': torch.int16, 'i32': torch.int32, 'i64': torch.int64,
  'bf16': torch.bfloat16, 'f16': torch.float16, 'f32': torch.float32, 'f64': torch.float64,
  'c64': torch.complex64, 'c128': torch.complex128,
  'i*': int, 'f*': float, 'c*': complex}

torch_dtype_to_code = {val: key for key, val in torch_dtypes.items()}

def make_torch_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return torch.zeros(1, dtype=dtype)

def torch_result_code(dtype1, dtype2):
  try:
    out = torch.add(make_torch_zero(dtype1), make_torch_zero(dtype2))
  except TypeError:
    return '-'
  else:
    if type(out) in {int, float, complex}:
      return torch_dtype_to_code[type(out)]
    else:
      return torch_dtype_to_code[out.dtype]


grid = [[torch_result_code(dtype1, dtype2)
         for dtype2 in torch_dtypes.values()]
        for dtype1 in torch_dtypes.values()]
table = pd.DataFrame(grid, index=torch_dtypes.keys(), columns=torch_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b b u8 - - - i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i64 f32 c64
u8 u8 u8 - - - i16 i16 i32 i64 bf16 f16 f32 f64 c64 c128 u8 f32 c64
u16 - - - - - - - - - - - - - - - - - -
u32 - - - - - - - - - - - - - - - - - -
u64 - - - - - - - - - - - - - - - - - -
i8 i8 i16 - - - i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i8 f32 c64
i16 i16 i16 - - - i16 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i16 f32 c64
i32 i32 i32 - - - i32 i32 i32 i64 bf16 f16 f32 f64 c64 c128 i32 f32 c64
i64 i64 i64 - - - i64 i64 i64 i64 bf16 f16 f32 f64 c64 c128 i64 f32 c64
bf16 bf16 bf16 - - - bf16 bf16 bf16 bf16 bf16 f32 f32 f64 c64 c128 bf16 bf16 c64
f16 f16 f16 - - - f16 f16 f16 f16 f32 f16 f32 f64 c64 c128 f16 f16 c64
f32 f32 f32 - - - f32 f32 f32 f32 f32 f32 f32 f64 c64 c128 f32 f32 c64
f64 f64 f64 - - - f64 f64 f64 f64 f64 f64 f64 f64 c128 c128 f64 f64 c128
c64 c64 c64 - - - c64 c64 c64 c64 c64 c64 c64 c128 c64 c128 c64 c64 c64
c128 c128 c128 - - - c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128
i* i64 u8 - - - i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i64 f32 c64
f* f32 f32 - - - f32 f32 f32 f32 bf16 f16 f32 f64 c64 c128 f32 f64 c64
c* c64 c64 - - - c64 c64 c64 c64 c64 c64 c64 c128 c64 c128 c64 c64 c128

JAX 类型提升:jax.numpy#

jax.numpy遵循 https://jax.ac.cn/en/latest/type_promotion.html 中列出的类型提升规则。这里我们使用i*f*c*来表示 Python 标量和弱类型数组。

隐藏代码单元格源代码
# @title
import jax
import jax.numpy as jnp
import pandas as pd
from IPython import display
jax.config.update('jax_enable_x64', True)

jnp_dtypes = {
  'b': jnp.bool_.dtype,
  'u8': jnp.uint8.dtype, 'u16': jnp.uint16.dtype, 'u32': jnp.uint32.dtype, 'u64': jnp.uint64.dtype,
  'i8': jnp.int8.dtype, 'i16': jnp.int16.dtype, 'i32': jnp.int32.dtype, 'i64': jnp.int64.dtype,
  'bf16': jnp.bfloat16.dtype, 'f16': jnp.float16.dtype, 'f32': jnp.float32.dtype, 'f64': jnp.float64.dtype,
  'c64': jnp.complex64.dtype, 'c128': jnp.complex128.dtype,
  'i*': int, 'f*': float, 'c*': complex}


jnp_dtype_to_code = {val: key for key, val in jnp_dtypes.items()}

def make_jnp_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return jnp.zeros((), dtype=dtype)

def jnp_result_code(dtype1, dtype2):
  try:
    out = jnp.add(make_jnp_zero(dtype1), make_jnp_zero(dtype2))
  except TypeError:
    return '-'
  else:
    if hasattr(out, 'aval') and out.aval.weak_type:
      return out.dtype.kind + '*'
    elif type(out) in {int, float, complex}:
      return jnp_dtype_to_code[type(out)]
    else:
      return jnp_dtype_to_code[out.dtype]

grid = [[jnp_result_code(dtype1, dtype2)
         for dtype2 in jnp_dtypes.values()]
        for dtype1 in jnp_dtypes.values()]
table = pd.DataFrame(grid, index=jnp_dtypes.keys(), columns=jnp_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
u8 u8 u8 u16 u32 u64 i16 i16 i32 i64 bf16 f16 f32 f64 c64 c128 u8 f* c*
u16 u16 u16 u16 u32 u64 i32 i32 i32 i64 bf16 f16 f32 f64 c64 c128 u16 f* c*
u32 u32 u32 u32 u32 u64 i64 i64 i64 i64 bf16 f16 f32 f64 c64 c128 u32 f* c*
u64 u64 u64 u64 u64 u64 f* f* f* f* bf16 f16 f32 f64 c64 c128 u64 f* c*
i8 i8 i16 i32 i64 f* i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i8 f* c*
i16 i16 i16 i32 i64 f* i16 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i16 f* c*
i32 i32 i32 i32 i64 f* i32 i32 i32 i64 bf16 f16 f32 f64 c64 c128 i32 f* c*
i64 i64 i64 i64 i64 f* i64 i64 i64 i64 bf16 f16 f32 f64 c64 c128 i64 f* c*
bf16 bf16 bf16 bf16 bf16 bf16 bf16 bf16 bf16 bf16 bf16 f32 f32 f64 c64 c128 bf16 bf16 c64
f16 f16 f16 f16 f16 f16 f16 f16 f16 f16 f32 f16 f32 f64 c64 c128 f16 f16 c64
f32 f32 f32 f32 f32 f32 f32 f32 f32 f32 f32 f32 f32 f64 c64 c128 f32 f32 c64
f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 f64 c128 c128 f64 f64 c128
c64 c64 c64 c64 c64 c64 c64 c64 c64 c64 c64 c64 c64 c128 c64 c128 c64 c64 c64
c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128 c128
i* i* u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
f* f* f* f* f* f* f* f* f* f* bf16 f16 f32 f64 c64 c128 f* f* c*
c* c* c* c* c* c* c* c* c* c* c64 c64 c64 c128 c64 c128 c* c* c*

JAX 类型提升:jax.lax#

jax.lax是更低级的,它不进行任何隐式提升。这里我们使用i*f*c*来表示 Python 标量和弱类型数组。

隐藏代码单元格源代码
# @title
import jax
import jax.numpy as jnp
import pandas as pd
from IPython import display
jax.config.update('jax_enable_x64', True)

jnp_dtypes = {
  'b': jnp.bool_.dtype,
  'u8': jnp.uint8.dtype, 'u16': jnp.uint16.dtype, 'u32': jnp.uint32.dtype, 'u64': jnp.uint64.dtype,
  'i8': jnp.int8.dtype, 'i16': jnp.int16.dtype, 'i32': jnp.int32.dtype, 'i64': jnp.int64.dtype,
  'bf16': jnp.bfloat16.dtype, 'f16': jnp.float16.dtype, 'f32': jnp.float32.dtype, 'f64': jnp.float64.dtype,
  'c64': jnp.complex64.dtype, 'c128': jnp.complex128.dtype,
  'i*': int, 'f*': float, 'c*': complex}


jnp_dtype_to_code = {val: key for key, val in jnp_dtypes.items()}

def make_jnp_zero(dtype):
  if dtype in {int, float, complex}:
    return dtype(0)
  else:
    return jnp.zeros((), dtype=dtype)

def jnp_result_code(dtype1, dtype2):
  try:
    out = jax.lax.add(make_jnp_zero(dtype1), make_jnp_zero(dtype2))
  except TypeError:
    return '-'
  else:
    if hasattr(out, 'aval') and out.aval.weak_type:
      return out.dtype.kind + '*'
    elif type(out) in {int, float, complex}:
      return jnp_dtype_to_code[type(out)]
    else:
      return jnp_dtype_to_code[out.dtype]

grid = [[jnp_result_code(dtype1, dtype2)
         for dtype2 in jnp_dtypes.values()]
        for dtype1 in jnp_dtypes.values()]
table = pd.DataFrame(grid, index=jnp_dtypes.keys(), columns=jnp_dtypes.keys())
display.HTML(table.to_html())
b u8 u16 u32 u64 i8 i16 i32 i64 bf16 f16 f32 f64 c64 c128 i* f* c*
b - - - - - - - - - - - - - - - - - -
u8 - u8 - - - - - - - - - - - - - - - -
u16 - - u16 - - - - - - - - - - - - - - -
u32 - - - u32 - - - - - - - - - - - - - -
u64 - - - - u64 - - - - - - - - - - - - -
i8 - - - - - i8 - - - - - - - - - - - -
i16 - - - - - - i16 - - - - - - - - - - -
i32 - - - - - - - i32 - - - - - - - - - -
i64 - - - - - - - - i64 - - - - - - i64 - -
bf16 - - - - - - - - - bf16 - - - - - - - -
f16 - - - - - - - - - - f16 - - - - - - -
f32 - - - - - - - - - - - f32 - - - - - -
f64 - - - - - - - - - - - - f64 - - - f64 -
c64 - - - - - - - - - - - - - c64 - - - -
c128 - - - - - - - - - - - - - - c128 - - c128
i* - - - - - - - - i64 - - - - - - i* - -
f* - - - - - - - - - - - - f64 - - - f* -
c* - - - - - - - - - - - - - - c128 - - c*