Pallas 快速入门#

Pallas 是 JAX 的扩展,它允许编写用于 GPU 和 TPU 的自定义内核。Pallas 允许您使用相同的 JAX 函数和 API,但在更低的抽象级别上运行。

具体来说,Pallas 需要用户考虑内存访问以及如何在硬件加速器中的多个计算单元之间划分计算。在 GPU 上,Pallas 降级到 Triton,而在 TPU 上,Pallas 降级到 Mosaic。

让我们深入了解一些示例。

注意:Pallas 仍然是一个实验性 API,您可能会受到更改的影响!

Pallas 中的 Hello World#

from functools import partial

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np

我们首先在 Pallas 中编写“Hello World”,这是一个将两个向量相加的内核。

def add_vectors_kernel(x_ref, y_ref, o_ref):
  x, y = x_ref[...], y_ref[...]
  o_ref[...] = x + y

Ref 类型

让我们稍微分析一下这个函数。与您可能编写的大多数 JAX 函数不同,它不接受 jax.Array 作为输入,也不返回任何值。相反,它接受Ref 对象作为输入。请注意,我们也没有任何输出,但我们得到了一个 o_ref,它对应于所需的输出。

Ref 中读取

在主体中,我们首先从 x_refy_ref 中读取,由 [...] 表示(省略号表示我们正在读取整个 Ref;或者我们也可以使用 x_ref[:])。像这样从 Ref 中读取会返回一个 jax.Array

写入到 Refs

然后我们将 x + y 写入到 o_ref。在 JAX 的历史中,一直不支持变异 - jax.Arrays 是不可变的!Refs 是新(实验性)类型,允许在特定情况下进行变异。我们可以将写入到 Ref 解释为变异其底层缓冲区。

因此,我们编写了一个称为“内核”的东西,我们将其定义为一个程序,该程序将作为加速器上的原子执行单元运行,没有任何与主机的交互。我们如何在 JAX 计算中调用它?我们使用 pallas_call 高阶函数。

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(
      add_vectors_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
  )(x, y)
add_vectors(jnp.arange(8), jnp.arange(8))
Array([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=int32)

pallas_call 将 Pallas 内核函数提升为一个操作,该操作可以作为更大 JAX 程序的一部分进行调用。但是,要做到这一点,它需要一些更多细节。在这里,我们指定 out_shape,这是一个具有 .shape.dtype(或它们的列表)的对象。out_shape 确定我们 add_vector_kernelo_ref 的形状/数据类型。

pallas_call 返回一个函数,该函数接收并返回 jax.Arrays。

这里到底发生了什么?

到目前为止,我们已经描述了如何考虑 Pallas 内核,但我们实际上已经完成的是编写一个非常接近计算单元执行的函数。

在 GPU 上,x_ref 对应于高带宽内存 (HBM) 中的值,当我们执行 x_ref[...] 时,我们将值从 HBM 复制到静态 RAM (SRAM)(总的来说,这是一个代价高昂的操作!)。然后,我们使用 GPU 矢量计算执行加法,然后将 SRAM 中的结果值复制回 HBM。

在 TPU 上,我们做了一些不同的事情。在内核执行之前,我们将值从 HBM 提取到 SRAM 中。x_ref 因此对应于 SRAM 中的值,当我们执行 x_ref[...] 时,我们将值从 SRAM 复制到寄存器中。然后,我们使用 TPU 矢量计算执行加法,然后将结果值复制回 SRAM。内核执行后,SRAM 值被复制回 HBM。

我们正在编写特定于后端的 Pallas 指南。敬请期待!

Pallas 编程模型#

在我们的“hello world”示例中,我们编写了一个非常简单的内核。它利用了我们的 8 大小数组可以舒适地放入硬件加速器 SRAM 的事实。在大多数现实世界的应用程序中,情况并非如此!

编写 Pallas 内核的一部分是考虑如何获取驻留在高带宽内存 (HBM,也称为 DRAM) 中的大型数组,并表达对这些数组的“块”进行操作的计算,这些块可以放入 SRAM 中。

通过示例进行网格化#

为了自动“分割”输入和输出,您需要向 pallas_call 提供一个 gridBlockSpecs。

一个 grid 是一个整数元组(例如 ()(2, 3, 4)(8,)),它指定一个迭代空间。例如,网格 (4, 5) 将具有 20 个元素:(0, 0), (0, 1), ..., (0, 4), (1, 0), ..., (3, 4)。我们对每个元素运行一次内核函数,这是一种单程序多数据 (SPMD) 编程风格。

A visualization of a 2D grid

二维网格

当我们向 pallas_call 提供一个 grid 时,内核将执行与 prod(grid) 一样多次。每次调用被称为“程序”。要访问内核当前正在执行的程序(即网格的哪个元素),我们使用 program_id(axis=...)。例如,对于调用 (1, 2)program_id(axis=0) 返回 1,而 program_id(axis=1) 返回 2

这是一个使用 gridprogram_id 的内核示例。

def iota_kernel(o_ref):
  i = pl.program_id(0)
  o_ref[i] = i

现在,我们使用 pallas_call 以及额外的 grid 参数来执行它。

def iota(size: int):
  return pl.pallas_call(iota_kernel,
                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
                        grid=(size,))()
iota(8)
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)

在 GPU 上,每个程序都在独立的线程上并行执行。因此,我们需要考虑写入 HBM 的竞争条件。一个合理的方法是以这样一种方式编写我们的内核,即不同的程序写入 HBM 中的非重叠位置,以避免这些并行写入。另一方面,并行化计算是我们如何快速执行矩阵乘法等操作的方法。

在 TPU 上,程序以并行和顺序的组合(取决于体系结构)执行,因此存在一些不同的考虑因素。

您可以在 grid,也称为循环中的内核 中阅读更多详细信息。

通过示例进行块规范#

有了 gridprogram_id,Pallas 提供了一个抽象,它负责处理许多内核中看到的常见索引模式。为了建立直觉,让我们尝试实现矩阵乘法。

在 Pallas 中实现矩阵乘法的一个简单策略是递归地实现它。我们知道我们底层的硬件支持小型矩阵乘法(使用 GPU 和 TPU 张量核心),因此我们只需用更小的矩阵乘法来表达一个大型矩阵乘法。

假设我们有输入矩阵 \(X\)\(Y\),并计算 \(Z = XY\)。我们首先将 \(X\)\(Y\) 表达为块矩阵。\(X\) 将具有“行”块,而 \(Y\) 将具有“列”块。

\[\begin{split} \begin{align*} X = \begin{bmatrix} X_0 \\ X_1 \end{bmatrix} \end{align*} \end{split}\]
\[ \begin{align*} Y = \begin{bmatrix} Y_0 & Y_1 \end{bmatrix} \end{align*} \]
\[\begin{split} \begin{align*} Z &= \begin{bmatrix} X_0 \\ X_1 \end{bmatrix} \begin{matrix} \begin{bmatrix} Y_0 & Y_1 \end{bmatrix} \\ ~ \end{matrix} \\ &= \begin{bmatrix} X_0 Y_0 & X_0 Y_1 \\ X_1 Y_0 & X_1 Y_1 \end{bmatrix} \end{align*} \end{split}\]

我们的策略是,因为 \(Z\) 也是一个块矩阵,所以我们可以将 Pallas 内核中的每个程序分配给一个输出块。计算每个输出块对应于对 \(X\) 的“行”块和 \(Y\) 的“列”块进行更小的矩阵乘法。

为了表达这种模式,我们使用 BlockSpecs。一个 BlockSpec 为每个输入和输出指定一个块形状,并指定一个“索引映射”函数,该函数将一组程序索引映射到一个块索引。

A visualization of a BlockSpec`

一个 BlockSpec 的可视化

以一个具体的例子来说,假设我们想要将两个 (1024, 1024) 矩阵 xy 乘在一起以生成 z,并且想要将计算并行化 4 次。我们将 z 分成 4 个 (512, 512) 块,其中每个块都是使用 (512, 1024) x (1024, 512) 矩阵乘法计算的。为了表达这一点,我们首先使用一个 (2, 2) 网格(每个程序一个块)。

对于 x,我们使用 BlockSpec((512, 1024), lambda i, j: (i, 0)) - 这将 x 分割成“行”块。看看程序实例 (1, 0)(1, 1) 如何在 x 中选择 (1, 0) 块。对于 y,我们使用一个转置版本 BlockSpec((1024, 512), lambda i, j: (0, j))。最后,对于 z,我们使用 BlockSpec((512, 512), lambda i, j: (i, j))

这些 BlockSpecs 通过 in_specsout_specs 传递给 pallas_call

有关 BlockSpecs 的更多详细信息,请参见 BlockSpec,也称为如何将输入分割成块

在幕后,pallas_call 将自动将您的输入和输出分割成 Refs,用于将传递到内核的每个块。

def matmul_kernel(x_ref, y_ref, z_ref):
  z_ref[...] = x_ref[...] @ y_ref[...]

def matmul(x: jax.Array, y: jax.Array):
  return pl.pallas_call(
    matmul_kernel,
    out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
    grid=(2, 2),
    in_specs=[
        pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
        pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
    ],
    out_specs=pl.BlockSpec(
        (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j),
    )
  )(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y)
np.testing.assert_allclose(z, x @ y)

请注意,这是一个非常简单的矩阵乘法实现,但可以将其视为各种优化类型的起点。 让我们为矩阵乘法添加一个额外的功能:融合激活。 实际上很简单! 只需将高阶激活函数传递到内核中即可。

def matmul_kernel(x_ref, y_ref, z_ref, *, activation):
  z_ref[...] = activation(x_ref[...] @ y_ref[...])

def matmul(x: jax.Array, y: jax.Array, *, activation):
  return pl.pallas_call(
    partial(matmul_kernel, activation=activation),
    out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
    grid=(2, 2),
    in_specs=[
        pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
        pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
    ],
    out_specs=pl.BlockSpec(
        (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j)
    ),
  )(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y, activation=jax.nn.relu)
np.testing.assert_allclose(z, jax.nn.relu(x @ y))

总之,让我们重点介绍 Pallas 的一个很酷的功能:它与 jax.vmap 兼容! 要将此矩阵乘法转换为批处理版本,我们只需要 vmap 它。

k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (4, 1024, 1024))
y = jax.random.normal(k2, (4, 1024, 1024))
z = jax.vmap(partial(matmul, activation=jax.nn.relu))(x, y)
np.testing.assert_allclose(z, jax.nn.relu(jax.vmap(jnp.matmul)(x, y)))