使用 Pallas 编写 TPU 内核#

此页面重点介绍尝试在 Google TPU 上运行 Pallas 内核时很重要的细节。首先,TPU 后端仍处于实验阶段,并且只会接受 JAX NumPy 的一个子集。此外,为 TPU 编写高性能代码可能需要仔细考虑硬件的原生功能。虽然许多对硬件来说不自然的模式会被接受,但它们最终可能需要软件仿真,并且会降低计算速度。

警告

此功能仍应被视为实验性功能,因为工作仍在进行中(特别是改进错误消息方面)。

注意

虽然此处描述的所有功能都处于实验阶段,但我们仍然非常重视维护其正确性。因此,在尝试编写 TPU 内核时,看到“未实现”错误并不少见。但是,如果内核被编译器接受,则_必须_返回预期结果。

如果您看到意外输出,请将其与使用 interpret=True 传递给 pallas_call 运行的内核进行比较。如果结果出现偏差,请提交 错误报告

什么是 TPU?#

A TPUv4 board

TPU 是 Google 开发的一种硬件加速器。您可以将 TPU 视为 GPU,但专门用于机器学习工作负载。因此,它们的架构差异很大。但是,我们相信 Pallas 可以让您轻松开始编写 TPU 内核,即使您不完全了解底层硬件。话虽如此,深入了解硬件肯定会让您更容易编写高性能内核。

简而言之,TPU 和 GPU 之间的主要区别在于 TPU 是具有非常宽向量寄存器(有点像 CPU!)的顺序机器。同时,它们允许软件在后台调度某些操作,使其相对于主指令流异步执行。这包括 HBM 内存访问(不能直接发出,而必须由 DMA 子单元预取到较低级别的内存层次结构)、矩阵乘法(由 MXU 单元支持)或矩阵转置和置换(由 XLU 单元支持)。

如果您有兴趣详细了解 TPU 架构,我们建议您阅读多年来发表的一系列论文。虽然其中许多论文讨论了特定 TPU 代,但许多描述的想法也适用于后来的代。

值得注意的属性和限制#

BlockSpec 和网格迭代#

BlockSpec(参见 BlockSpec,也称为如何将输入分块)在 Pallas 中通常按预期工作——内核体的每次调用都可以访问输入的切片,并用于初始化输出的切片。

注意

并非所有块形状都受支持。在 TPU 上,仅支持秩至少为 1 的块

。此外,块形状的最后两个维度必须分别能被 8 和 128 整除,或者等于整个数组的相应维度。

Pallas TPU 内核的一个有趣方面是它们处理内存空间的方式:虽然 pallas_call 的输入通常位于 HBM(主 TPU 内存)中,但传递给内核体的引用将指向内存层次结构较低级别(VMEM 或 SMEM)中的缓冲区。这使内核体能够以非常高的速度写入和读取它们,而所有与 HBM(具有非常高的延迟)的通信都由编译器处理,并与计算重叠。

此外,与 GPU 相比,TPU 实际上是高度顺序的机器。因此,网格通常不是并行处理的,而是按字典顺序顺序处理的(尽管请参阅 多核 TPU 配置 部分以了解例外情况)。这解锁了一些有趣的功能

  • 当两个(字典序)连续的网格索引使用输入的同一切片时,将跳过第二次迭代的 HBM 传输,因为数据已可用。

  • 内核体的多次调用可以写入输出的同一切片,而不会出现任何竞争条件的风险。但是,我们确实要求写入特定切片的所有调用都是连续的。

输出的“连续”限制通常意味着网格维度的某些前缀始终改变调用需要访问的输出切片,而输出窗口对于其余后缀保持不变。

例如,在为矩阵乘法实现 Pallas TPU 内核时,通常会使用 3 维网格:前两个维度对应于沿左操作数的第一轴和第二个操作数的第二轴进行切片。第三个也是_最后一个_网格轴将平铺约简维度。对应于约简维度的网格轴必须是最后一个,因为输出窗口沿此轴不变化。然后,输出引用可以用作部分结果的累加器。

注意

对于如此低级别的内存层次结构(16MB+),VMEM 非常大,因此可以使用较大的窗口大小。而且,通常情况下,窗口大小越大,最终的硬件利用率就越好。但是,可以指定一个窗口大小,该窗口大小(以及保存溢出向量寄存器所需的存储空间)超过 VMEM 的大小。在这种情况下,您可能会看到一个低级编译器错误消息,抱怨内存不足错误。

维度排序是有意义的#

在 JAX 程序中,jax.jit 内部中间数组的排序通常对性能没有影响,因为编译器可以自由地重新排列它们。但是,由于 Pallas 旨在公开较低级别的功能,因此维度顺序会对生成的代码质量产生重大影响。

回想一下,TPU 在 2D 向量寄存器上执行大部分计算。Pallas TPU 将始终只考虑将中间数组的最后两个维度映射到这些向量寄存器维度(分别为子车道和车道)。形状为 (n, 1, 1) 的数组保证至少需要 n 个向量寄存器来表示。如果 n 变得太大,这会导致溢出,并可能由于过大的内存占用而导致 VMEM OOM 错误。但它也可能不会——低级编译器可以自由地重新排列指令以降低寄存器压力,并且实际上非常擅长此操作。尽管如此,一个好的经验法则是保持最后两个维度较大(尤其是最后一个维度),同时保持前导维度较小。

多核 TPU 配置#

在较新的 TPU 代中,芯片上的两个内核通常被抽象为单个设备。为了利用多个内核,Pallas 必须打破顺序网格执行保证,并且需要将网格轴之一并行化到内核上。这是一个选择加入的过程。为了允许这样做,pallas_call 需要一个名为 dimension_semantics 的额外参数

该参数是一个列表,其条目数与网格中轴数相同。只有 parallel 维度可以跨内核进行分区。根据经验法则,除非输出窗口不变化,否则维度是并行的。因此,dimension_semantics 始终是一定数量的 parallel 轴,后跟一定数量的 arbitrary 轴。

虽然将内核分区到 2 核 TPU 设备上通常会导致 2 倍的加速,但实际上它可能要小得多。如果内核体的不同实例具有高度不同的成本,则尤其如此。如果所有昂贵的步骤都映射到一个内核,但所有廉价的步骤都分配给另一个内核,则第二个内核将处于空闲状态,直到第一个内核完成其任务。

Pallas TPU 通常倾向于对大小为 TPU 内核数量倍数的轴进行分区,并倾向于对前导网格轴进行分区。

将操作数放置在 SMEM 中#

TPU 上的大部分计算将在向量单元上进行。尽管如此,在许多情况下,执行一些标量操作非常有用,例如,执行控制流。出于这个原因,TPU 带有一个单独的标量单元,以及一个连接到它的单独标量内存 (SMEM)。根据经验法则,任何用于执行控制流决策的数据都应放置在 SMEM 中。

SMEM 是一种低延迟内存,支持随机访问,但仅允许您使用单个指令读取和写入 32 位值(与 VMEM 事务的 4KBi 粒度相比非常小,但由于缺乏对齐要求而更加灵活!)。

当实现不以规则模式访问输入平铺的内核(例如,在编写块稀疏内核时)时,标量内存也非常有用。在 Pallas 中,可以通过将 pallas_callgrid 参数替换为具有非零 num_scalar_prefetch 参数的 PrefetchScalarGridSpecgrid_spec 来实现这一点。如果 num_scalar_prefetchn,则 pallas_call 的前 n 个参数将放置在 SMEM 中。对于这些参数,不应指定任何 BlockSpec。但是,所有后续参数的 BlockSpec 不仅会接收网格索引,还会接收前导操作数的 SMEM 引用。

注意

我们正在努力为此功能实现示例。敬请期待!

支持的数据类型#

目前 Pallas TPU 仅支持以下数据类型

  • jnp.float32

  • jnp.bfloat16

  • jnp.int*(所有精度,除了 jnp.int4

  • jnp.uint*(所有精度)

计算放置#

所有标量(即 0D)数组将存储在标量寄存器中,并且对其的操作将在标量内核上执行。所有其他操作(即使在单元素但 1D+ 数组上)将在向量内核上执行。

支持的操作#

矩阵乘法#

矩阵乘法始终以 float32 格式生成结果。如果您的输入不是 float32,我们建议使用 lax.dot 并将 preferred_element_type 设置为 jnp.float32

使用 lax.dot_general 时,可以将矩阵乘法操作数最后两个维度的转置融合到操作中,这可以提高整体内核性能。

精度控制#

Pallas TPU 降维器感知 jax.default_matmul_precision。为了获得最佳性能(以及最低精度),请使用 bfloat16。如果您关心数值精度,则可能需要将精度设置为 float32

警告

即使您将 32 位操作数传递给矩阵乘法,除非请求 float32 精度,否则它们将被四舍五入为 bfloat16

转置#

如果该值至少具有 4 个维度,则除最后两个轴之外的所有轴的任意转置都是免费的。否则,仅实现最后两个轴的转置。请注意,最后两个维度的某些转置可以融合到矩阵乘法中。

访问内存#

可以读取或更新引用的任意切片,但要受实现约束。目前,对 32 位宽的输入没有限制,但对于较窄的类型,只支持某些切片模式。始终支持在最后两个维度上分别以 8 和 128 的倍数对齐并具有长度为 8 和 128 的倍数的读写操作。

对矢量内存的读写通常发生在形状为 (8, 128) 的块上。因此,当读取或写入至少具有两个维度的引用时,当内存访问的基本偏移量具有可被平铺整除的索引,并且读取区域的大小是块大小的倍数时,可以获得最佳性能。

逐元素操作#

支持许多逐元素操作。值得注意的是,硬件通常只支持使用 32 位类型进行逐元素计算。当加载使用较低精度类型的操作数时,通常应在应用逐元素运算之前将其向上转换为 32 位类型。

值得注意的是,它们的成本可能差异很大。因此,我们概述了三类支持的操作:便宜 (🟢)、中等 (🌕) 和昂贵 (🔴)。

操作

成本

jnp.add+

🟢

jnp.sub-

🟢

jnp.mul*

🟢

/, //, %

🌕

jnp.maxjnp.min

🟢

jnp.where(选择)

🟢

jnp.abs

🟢

|, ^, &, ~

🟢

<<, >>

🟢

比较 (==,…)

🟢

类型转换 (.astype)

🟢

jnp.exp

🌕

jnp.tanh

🌕

jnp.pow

🌕

jnp.sin

🔴

jnp.cos

🔴

许多 JAX 函数都是根据其他 JAX 原语实现的,因此此列表可能并不全面。例如,jax.nn.relu 是根据比较实现的,jnp.where 也将在 Pallas 内核中运行。

数组构造器#

所有常量数组构造器都受支持 (jnp.onesjnp.zerosjnp.full)。值得注意的是,截至今天,jax.random 模块**不**兼容 Pallas。

归约#

支持求和、最大值和最小值归约,但一次仅支持一个数组轴。

对最后一个数组维度的归约通常是最慢的。对倒数第二个维度的归约更快,但仍然比对前导维度的归约慢。

广播#

广播的性能特征与归约的性能特征非常相似。始终支持并免费广播除最后两个维度之外的所有维度。沿倒数第二个维度的广播较慢,而沿最后一个维度的广播最慢。

重塑#

像往常一样,除最后两个维度之外的所有维度的重塑都受支持且免费。

当重塑可以修改数组的最后两个维度时,仅支持两种情况:(1) 将某些前导维度展平到倒数第二个维度,或 (2) 添加一个由归约刚刚删除的维度。

控制流#

TPU 后端目前对控制流的支持有限。当前支持的函数为 condfori_loopfor_loop。但是,循环原语目前在编译期间会被完全展开,因此请尝试将循环迭代次数保持在合理的小范围内。

过度使用控制流会导致低级代码生成出现重大回归,建议尝试将尽可能多的计算密集型操作压缩到单个基本块中。