流水线#
在本指南中,我们将介绍 TPU 中内存空间的工作原理,以及如何在 Pallas 中编写将内存 I/O 与计算重叠的流水线。
#@title Imports
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
TPU 及其内存空间#
TPU 及其 TensorCore 由内存空间(数组可以驻留的地方)、寄存器(临时存储标量和数组值)和计算单元(对寄存器中的值进行计算)组成。下图是一个 TPU 的示意图,其中 x
和 y
是驻留在高带宽内存 (HBM) 中的数组
让我们更详细地讨论一下此图的组成部分
**内存空间**:TPU 具有高带宽内存 (HBM),这通常是我们所说的“设备内存”。此外还有向量内存 (VMEM),一种用于存储向量和数组值的缓存,以及标量内存 (SMEM),一种用于存储标量值的缓存。
**寄存器**:TensorCore 有两种主要的寄存器类型:向量寄存器 (VREG) 存储数组值,标量寄存器 (SREG) 存储标量值。值可以从各自的缓存(VREG 的 VMEM 和 SREG 的 SMEM)加载到内存中。
**计算单元**:TensorCore 具有标量单元、向量单元 (VPU) 和矩阵单元 (MXU),可以执行数值计算。计算单元对驻留在 SREG 和 VREG 中的值进行操作,并将输出值也输出到这些寄存器中。
为了对驻留在 HBM 中的值 x
和 y
执行矢量化计算,我们需要
将值
x
和y
复制到 VMEM 中。将值从 VMEM 加载到 VREG 中。
使用 VPU 或 MXU 执行计算,并将输出存储在 VREG 中。
将输出 VREG 中的值存储到 VMEM 中。
将 VMEM 中的输出值复制回 HBM。
让我们实现一个执行此操作的 Pallas 函数吧!
def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):
# Load x and y from VMEM into VREGs
x_vregs = x_vmem_ref[:, :]
y_vregs = y_vmem_ref[:, :]
# Execute a vectorized add
z_vregs = x_vregs + y_vregs
# Store the output values in VREGs back into VMEM
z_vmem_ref[:, :] = z_vregs
def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:
# pallas_call will first allocate scratch buffers for `x` and `y` in VMEM.
# It will then copy `x` and `y` from HBM into VMEM.
z = pl.pallas_call(
add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x, y)
# pallas_call will also copy the output from VMEM back into HBM.
return z
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
我们编写了两个函数:add_matrices_kernel
和 add_matrices
。
add_matrices_kernel
使用位于 VMEM 中的 Ref
进行操作。从 VMEM Ref
加载会产生一个位于 VREG 中的值。VREG 中的值的行为类似于 jax.Array
,我们可以使用 jnp
和 jax.lax
操作对其进行操作以生成位于 VREG 中的新值。当生成我们想要返回的值时,我们将其存储在输出 VMEM Ref
中。
add_matrices
函数作用于 jax.Array
并返回一个 jax.Array
。在函数内部,我们将 x
和 y
传递给 pallas_call
。 pallas_call
负责将 x
和 y
复制到 VMEM,并为内核操作的 VMEM 缓冲区分配内存(包括分配输出 VMEM 缓冲区 z_vmem_ref
)。内核函数运行完成后,pallas_call
还会将 z_vmem_ref
中的值复制到 HBM,从而产生输出 jax.Array
。
使用 VMEM/SMEM 的约束#
Pallas 公开了对 VMEM 和 SMEM 等较低级别内存空间的访问,但编写利用它们的内核会增加一些注意事项。
内存容量。VMEM 和 SMEM 非常小!v4 TPU 上的 VMEM 只有 16MiB,而 SMEM 的范围在几十到几百 KiB 之间。如果我们的数组太大,我们甚至无法将它们全部放入 VMEM 中。作为参考,一个
f32[2048, 2048]
数组是 16MiB,因此我们上面的内核无法扩展到中等大小以上的数组。内存带宽。与大多数计算指令相比,复制到/从 HBM 和 VMEM 之间需要很长时间。上面的
add_matrices
函数可能会花费更多时间在 HBM 和 VMEM 之间复制,而不是实际执行加法运算本身。
考虑到这两个约束,我们将不得不重新考虑从 TPU 中获得性能的策略。
入门:流水线#
对我们的计算进行流水线处理提供了一种同时处理内存容量和带宽约束的方法。流水线处理是什么意思呢?
目标是:在复制到/从 HBM 和 VMEM 的同时,并行利用我们的计算单元。表面上这很困难,因为在我们上面的程序中,我们在开始使用 x
和 y
进行任何计算之前,先复制了所有的 x
和 y
,从而在复制和计算之间产生依赖关系。
但是,如果我们可以将我们的计算分成几个子计算(例如,当我们添加两个矩阵时,我们可以将其表示为将原始矩阵的“块”加在一起),我们现在可以将其中一个子计算的复制与另一个子计算的计算重叠。让我们来看一个简单的例子。
假设我们将数组 x
和 y
分成 x1, x2
和 y1, y2
(例如,沿前导轴拆分,每个输入产生两个 (256, 512)
数组)。我们现在可以执行以下流水线计算。
将
x1
和y1
复制到 VMEM 中。开始将
x2
和y2
复制到 VMEM 中。将
x1, y1
从 VMEM 加载到 VREG 中。使用计算单元执行
z1 = x1 + y1
。将
z1
存储到 VMEM 中。开始将
z1
从 VMEM 复制回 HBM。等待
x2, y2
被复制到 VMEM 中。将
x2, y2
从 VMEM 加载到 VREG 中。使用计算单元执行
z2 = x2 + y2
。将
z2
存储到 VMEM 中。等待
z1
被复制到 HBM 中。开始将
z2
从 VMEM 复制回 HBM。等待
z2
被复制到 HBM 中。
在任何时候进行计算时,我们都在异步复制某些内容。这意味着一些用于复制的时间没有浪费。
确定流水线计算效率的两个最重要的数字是 a) 我们需要执行多少浮点运算 (FLOP) 和 b) 我们需要复制多少字节来执行该计算。这两个数字的比率(FLOP/内存使用量)称为操作的算术强度,并决定我们的流水线是计算受限还是内存受限。
Pallas 中的流水线#
如何在 Pallas 中实现像上面那样的流水线?这似乎是一系列复杂异步数据操作和执行内核,手动实现起来很麻烦。别担心!Pallas 提供了一个 API,可以通过 grid
和 BlockSpec
来表达流水线,而无需太多样板代码。
请注意,在上面的流水线示例中,我们多次执行相同的逻辑:步骤 3-5 和 8-10 都执行相同的操作,只是在不同的输入上。 jax.experimental.pallas.pallas_call()
提供了一种通过使用 grid
参数多次执行内核的方法。请参阅 grid,也称为循环中的内核。
我们还使用 jax.experimental.pallas.BlockSpec
来指定如何构造每个内核调用的输入。请参阅 BlockSpec,也称为如何将输入分块。
在上面的流水线示例中,我们有 (512, 512)
形状的数组,并将其沿前导维度分成两个 (256, 512)
形状的数组。在这个流水线中,我们的 BlockSpec.block_shape
将是 (256, 512)
。在第一次迭代中,我们希望选择 x1
,在第二次迭代中,我们希望使用 x2
。这可以通过以下 index_map
表示。
def x_index_map(i):
return (i, 0)
然后我们将构建 BlockSpec
。
block_spec = pl.BlockSpec((256, 512), x_index_map)
y
和 z
的 BlockSpec
将与 x
的相同。
整合在一起#
我们通过 grid
、in_specs
和 out_specs
将这些参数传递给 pallas_call
(in_specs
对应于位置参数的元组,out_specs
对应于输出)。
def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:
block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(2,)
)(x, y)
add_matrices_pipelined(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
我们只在原始函数中添加了一小段代码来添加自动流水线,但 BlockSpec
和 grid
做了很多繁重的工作!
它是如何工作的呢?好吧,BlockSpec
提供了足够的信息来开始从 HBM 到 VMEM 的输入块预取。例如,如果我们正在开始 grid
的第 i
次迭代,我们可以将 i + 1
传递给 index_map
函数以获取下一迭代所需的块。然后我们可以为这些块启动异步复制。类似地,对于输出,我们可以在开始当前迭代的输出复制之前等待前一次迭代的输出被复制。
参数化流水线#
通常会在内核中参数化块形状。块大小可能是优化 Pallas 内核性能时需要调整的最重要参数!它们使我们能够控制流水线(例如,选择较小的块会增加流水线循环的迭代次数,其中每次迭代的工作量更少)。
此外,我们还可以沿第二维度分割输入和输出(我们现在只沿第一个维度分割)。让我们编写一个更通用的内核来处理这两个特性。
def add_matrices_pipelined_2d(
x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
) -> jax.Array:
m, n = x.shape
block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(m // bm, n // bn),
)(x, y)
np.testing.assert_array_equal(
add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y
)
np.testing.assert_array_equal(
add_matrices_pipelined_2d(x, y, bm=128, bn=128), x + y
)
np.testing.assert_array_equal(
add_matrices_pipelined_2d(x, y, bm=512, bn=512), x + y
)
处理归约#
如何使用 pallas_call
实现类似 jnp.sum
的操作?具体来说,我们希望跨归约维度进行流水线处理。
以将一个 (8, 512, 512)
形状的数组归约为 (512, 512)
形状的数组为例。
x = jnp.ones((8, 512, 512))
jnp.sum(x, axis=0)
Array([[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
...,
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.]], dtype=float32)
要使用pallas_call
实现这一点,我们可以使用大小为(8,)
的网格,并在每次迭代i
中将x[i]
加载到VMEM中。然后,我们可以将x[i]
添加到输出VMEM缓冲区。让我们首先简单地实现这一点。
# Warning: this implementation is incorrect!
def naive_sum_kernel(x_ref, o_ref):
o_ref[...] += x_ref[...]
def naive_sum(x: jax.Array) -> jax.Array:
grid, *out_shape = x.shape
return pl.pallas_call(
naive_sum_kernel,
grid=grid,
# None in `block_shape` means we pick a size of 1 and squeeze it away
in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],
out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),
out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
)(x)
naive_sum(x)
Array([[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.],
...,
[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.]], dtype=float32)
请注意我们是如何设置BlockSpec
的:我们将整个(512, 512)
维度加载到VMEM中(没有流水线),但在index_map
中每次迭代选择x
的第i
个维度。我们在块形状中为该维度使用None
,这表示我们正在从x
中选择一个单一维度,我们希望在内核中将其压缩掉。因此,x_ref
在VMEM中也具有(512, 512)
的形状。
out_spec
使用lambda i: (0, 0)
作为其index_map
,表示o_ref
在流水线过程中保持不变。这意味着我们可以在每次迭代中通过读取和写入它来更新其值。或者可以吗?实际上有一个问题:o_ref
最初是垃圾,这意味着我们将累积到垃圾中。这将导致整个函数输出不正确的值!
因此,无论何时在内核中进行归约,我们都需要确保初始化存储归约值的Ref
。当我们在第0次迭代时,我们可以通过有条件地将值写入out_ref
来实现这一点。我们可以使用辅助函数pl.when
来实现这一点,它是一个围绕jax.lax.cond
的便捷包装器,以及pl.program_id
,它查询我们在网格轴中的哪次迭代。
def sum_kernel(x_ref, o_ref):
@pl.when(pl.program_id(axis=0) == 0)
def _():
o_ref[...] = jnp.zeros_like(o_ref)
o_ref[...] += x_ref[...]
def sum(x: jax.Array) -> jax.Array:
grid, *out_shape = x.shape
return pl.pallas_call(
sum_kernel,
grid=grid,
# None in `block_shape` means we pick a size of 1 and squeeze it away
in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],
out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),
out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)
)(x)
sum(x)
Array([[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
...,
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.]], dtype=float32)
此sum
函数现在输出正确的值!
关于Pallas中归约的最后一点需要注意的是,它们必须在网格的最次要(最右)维度上完成(我们在上面的示例中网格是一维的,因此我们正在对其最次要维度进行归约)。这是因为Pallas使用BlockSpec
、grid
和内核函数生成的流水线不会从HBM读取输出。一旦你将输出值写回HBM,你就无法再次访问它。因此,你不能跨任何重新访问的网格维度进行归约,因此所有归约都需要发生在最右边的维度上。
Megacore配置下的TPU#
某些TPU芯片有两个TensorCore,但对JAX用户显示为一个设备。这称为“megacore”。单独的TensorCore拥有各自独立的VMEM、VREG、SMEM、SREG和计算单元,但共享HBM。
从概念上讲,Megacore中的TPU的行为类似于非常简单的GPU,即它们只有两个线程。我们如何修改我们的内核以同时利用这两个TensorCore?
基本思想是,如果我们的计算中存在令人尴尬的并行维度,我们可以将这些维度跨TensorCore分割。我们可以通过向pallas_call
提供一个名为dimension_semantics
的注释来指示哪些维度是可并行化的。
def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(2,),
compiler_params=pltpu.TPUCompilerParams(dimension_semantics=("parallel",))
)(x, y)
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices_pipelined_megacore(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
dimension_semantics
应该是一个与grid
长度相同的元组,其中每个条目都是"parallel"
或"arbitrary"
。"parallel"
指示Pallas对应于该维度的for循环的迭代可以独立执行,而不会影响程序的正确性。"arbitrary"
指示Pallas不能对该网格维度做出任何假设,因此它不能被并行化。
通过指定dimension_semantics
,我们现在可以在每个TensorCore上同时执行内核。Pallas将自动处理网格的分割。
请注意,Megacore目前仅在TPU
v4
和TPUv5p
上可用。在其他平台上提供dimension_semantics
注释是一个无操作,但不指定它会导致仅使用一个TensorCore(即使有多个可用)。
结论#
在本指南中,我们介绍了如何使用pallas_call
、grid
和BlockSpec
来表达TPU流水线。我们介绍了如何通过多维网格表达嵌套循环,以及如何通过在归约开始时初始化累加器来处理归约。我们还学习了如何通过向内核添加注释来处理Megacore。
留给读者的练习
尝试实现一个
sum
内核,该内核也对其他维度进行流水线处理。为
add
内核和sum
内核添加megacore支持。