Pallas 快速入门#
Pallas 是 JAX 的一个扩展,它允许为 GPU 和 TPU 编写自定义内核。Pallas 允许您使用相同的 JAX 函数和 API,但在一个更低的抽象级别上运行。
具体来说,Pallas 要求用户考虑内存访问以及如何在硬件加速器中的多个计算单元之间划分计算。在 GPU 上,Pallas 降低到 Triton,在 TPU 上,Pallas 降低到 Mosaic。
让我们深入研究一些示例。
注意:Pallas 仍然是一个实验性 API,您可能会因更改而中断!
Pallas 中的 Hello World#
from functools import partial
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
我们首先用 Pallas 编写“hello world”,一个添加两个向量的内核。
def add_vectors_kernel(x_ref, y_ref, o_ref):
x, y = x_ref[...], y_ref[...]
o_ref[...] = x + y
Ref
类型
让我们稍微剖析一下这个函数。与您可能编写的大多数 JAX 函数不同,它不接受 jax.Array
作为输入,也不返回任何值。相反,它接受 Ref
对象作为输入,这些对象表示内存中的可变缓冲区。请注意,我们也没有任何输出,但我们得到了一个 o_ref
,它对应于所需的输出。
从 Ref
读取
在主体中,我们首先从 x_ref
和 y_ref
读取,由 [...]
指示(省略号表示我们正在读取整个 Ref
;或者我们也可以使用 x_ref[:]
)。像这样从 Ref
读取会返回一个 jax.Array
。
写入 Ref
然后我们将 x + y
写入 o_ref
。JAX 历史上不支持可变性——jax.Array
是不可变的!Ref
是一种新的(实验性的)类型,它允许在某些情况下进行可变性操作。我们可以将写入 Ref
解释为改变其底层缓冲区。
因此,我们编写了一个我们称之为“内核”的东西,我们将其定义为将在加速器上作为原子执行单元运行的程序,而无需与主机进行任何交互。我们如何从 JAX 计算中调用它?我们使用 pallas_call
高阶函数。
@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
return pl.pallas_call(
add_vectors_kernel,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x, y)
add_vectors(jnp.arange(8), jnp.arange(8))
Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)
pallas_call
将 Pallas 内核函数提升为一个可以作为更大的 JAX 程序的一部分调用的操作。但是,要做到这一点,它还需要一些额外的细节。这里我们指定 out_shape
,这是一个具有 .shape
和 .dtype
(或其列表)的对象。out_shape
确定我们 add_vector_kernel
中的 o_ref
的形状/数据类型。
pallas_call
返回一个接受和返回 jax.Array
的函数。
这里实际发生了什么?
到目前为止,我们已经描述了如何考虑 Pallas 内核,但我们实际上完成的是编写一个非常接近计算单元执行的函数,因为值被加载到内存层次结构的最内层(最快)部分。
在 GPU 上,x_ref
对应于高带宽内存 (HBM) 中的一个值,当我们执行 x_ref[...]
时,我们将该值从 HBM 复制到静态 RAM (SRAM) 中(一般来说,这是一个代价高昂的操作!)。然后我们使用 GPU 矢量计算来执行加法,然后将 SRAM 中的结果值复制回 HBM。
在 TPU 上,我们做的事情略有不同。在内核执行之前,我们将值从 HBM 获取到 SRAM 中。x_ref
因此对应于 SRAM 中的一个值,当我们执行 x_ref[...]
时,我们将该值从 SRAM 复制到一个寄存器中。然后我们使用 TPU 矢量计算来执行加法,然后将结果值复制回 SRAM 中。内核执行后,SRAM 值被复制回 HBM 中。
我们正在编写特定于后端的 Pallas 指南。即将推出!
Pallas 编程模型#
在我们的“hello world”示例中,我们编写了一个非常简单的内核。它利用了我们的 8 个大小的数组可以舒适地容纳在硬件加速器的 SRAM 中的事实。在大多数实际应用中,情况并非如此!
编写 Pallas 内核的一部分是思考如何获取存在于高带宽内存(HBM,也称为 DRAM)中的大型数组,并表达对可以容纳在 SRAM 中的这些数组的“块”进行操作的计算。
网格示例#
要自动“分割”输入和输出,您需要为 pallas_call
提供 grid
和 BlockSpec
。
grid
是一个整数元组(例如 ()
,(2, 3, 4)
或 (8,)
),它指定一个迭代空间。例如,网格 (4, 5)
将有 20 个元素:(0, 0), (0, 1), ..., (0, 4), (1, 0), ..., (3, 4)
。我们为每个元素运行一次内核函数,这是一种单程序多数据 (SPMD) 编程风格。
二维网格
当我们向 pallas_call
提供 grid
时,内核的执行次数与 prod(grid)
一样多。这些调用中的每一个都被称为“程序”。要访问内核当前正在执行哪个程序(即网格的哪个元素),我们使用 program_id(axis=...)
。例如,对于调用 (1, 2)
,program_id(axis=0)
返回 1
,program_id(axis=1)
返回 2
。
这是一个使用 grid
和 program_id
的内核示例。
def iota_kernel(o_ref):
i = pl.program_id(0)
o_ref[i] = i
现在我们使用带有额外 grid
参数的 pallas_call
来执行它。在 GPU 上,我们可以直接调用内核,如下所示
# GPU version
def iota(size: int):
return pl.pallas_call(iota_kernel,
out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
grid=(size,))()
iota(8)
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
TPU 区分矢量和标量内存空间,在这种情况下,输出必须放置在标量内存 (TPUMemorySpace.SMEM
) 中,因为 i
是标量。有关更多详细信息,请阅读 TPU 及其内存空间。要在 TPU 上调用上述内核,请运行
# TPU version
from jax.experimental.pallas import tpu as pltpu
def iota(size: int):
return pl.pallas_call(iota_kernel,
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
grid=(size,))()
iota(8)
网格语义#
在 GPU 上,每个程序都在单独的线程上并行执行。因此,我们需要考虑在写入 HBM 时的竞争条件。一种合理的方法是以这样一种方式编写我们的内核,即不同的程序写入 HBM 中不相交的位置,以避免这些并行写入。另一方面,并行化计算是我们如何真正快速地执行矩阵乘法等操作的方式。
相比之下,TPU 的工作方式类似于非常宽的 SIMD 机器。一些 TPU 模型包含多个核心,但在许多情况下,可以将 TPU 视为单线程处理器。TPU 上的网格可以在并行和顺序维度的组合中指定,其中顺序维度保证按顺序运行。
您可以在 网格,又名循环中的内核 和 值得注意的属性和限制 中阅读更多详细信息。
块规范示例#
考虑到 grid
和 program_id
,Pallas 提供了一种抽象,可以处理在许多内核中看到的一些常见索引模式。为了建立直觉,让我们尝试实现矩阵乘法。
在 Pallas 中实现矩阵乘法的一种简单策略是递归地实现它。我们知道我们的底层硬件支持小型矩阵乘法(使用 GPU 和 TPU 张量核),因此我们只是用较小的矩阵乘法来表示一个大的矩阵乘法。
假设我们有输入矩阵 \(X\) 和 \(Y\),并且正在计算 \(Z = XY\)。我们首先将 \(X\) 和 \(Y\) 表示为分块矩阵。\(X\) 将具有“行”块,\(Y\) 将具有“列”块。
我们的策略是,因为 \(Z\) 也是一个分块矩阵,我们可以将 Pallas 内核中的每个程序分配给其中一个输出块。计算每个输出块对应于在 \(X\) 的“行”块和 \(Y\) 的“列”块之间进行较小的矩阵乘法。
为了表达这种模式,我们使用 BlockSpec
。 BlockSpec
为每个输入和输出指定一个块形状,以及一个“索引映射”函数,该函数将一组程序索引映射到块索引。
BlockSpec
的可视化
举一个具体的例子,假设我们想要将两个 (1024, 1024)
的矩阵 x
和 y
相乘,得到 z
,并且希望将计算并行化为 4 路。我们将 z
分割成 4 个 (512, 512)
的块,其中每个块都通过 (512, 1024) x (1024, 512)
的矩阵乘法计算得出。为了表达这一点,我们首先使用一个 (2, 2)
的网格(每个程序一个块)。
对于 x
,我们使用 BlockSpec((512, 1024), lambda i, j: (i, 0))
– 这会将 x
分割成 “行” 块。要理解这一点,请观察程序实例 (1, 0)
和 (1, 1)
如何在 x
中选择 (1, 0)
块。对于 y
,我们使用转置版本 BlockSpec((1024, 512), lambda i, j: (0, j))
。最后,对于 z
,我们使用 BlockSpec((512, 512), lambda i, j: (i, j))
。
这些 BlockSpec
通过 in_specs
和 out_specs
传递到 pallas_call
中。
有关 BlockSpec
的更多详细信息,请参阅BlockSpec,即如何将输入分块。
在底层,pallas_call
会自动将您的输入和输出分割成每个块的 Ref
,这些 Ref
将被传递到内核中。
def matmul_kernel(x_ref, y_ref, z_ref):
z_ref[...] = x_ref[...] @ y_ref[...]
def matmul(x: jax.Array, y: jax.Array):
return pl.pallas_call(
matmul_kernel,
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
grid=(2, 2),
in_specs=[
pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
],
out_specs=pl.BlockSpec(
(x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j),
)
)(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y)
np.testing.assert_allclose(z, x @ y)
请注意,这只是一个非常简单的矩阵乘法实现,但可以将其视为各种优化的起点。让我们为矩阵乘法添加一个额外的功能:融合激活。这实际上非常容易!只需将高阶激活函数传递到内核中即可。
def matmul_kernel(x_ref, y_ref, z_ref, *, activation):
z_ref[...] = activation(x_ref[...] @ y_ref[...])
def matmul(x: jax.Array, y: jax.Array, *, activation):
return pl.pallas_call(
partial(matmul_kernel, activation=activation),
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
grid=(2, 2),
in_specs=[
pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
],
out_specs=pl.BlockSpec(
(x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j)
),
)(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y, activation=jax.nn.relu)
np.testing.assert_allclose(z, jax.nn.relu(x @ y))
最后,让我们强调一下 Pallas 的一个很酷的功能:它可以与 jax.vmap
组合使用!要将此矩阵乘法转换为批量版本,我们只需要 vmap
它即可。
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (4, 1024, 1024))
y = jax.random.normal(k2, (4, 1024, 1024))
z = jax.vmap(partial(matmul, activation=jax.nn.relu))(x, y)
np.testing.assert_allclose(z, jax.nn.relu(jax.vmap(jnp.matmul)(x, y)))