Pallas 设计#

在本文件中,我们解释了最初的 Pallas 设计。这是早期设计决策的一些快照,自那时以来,Pallas 的特定 API 可能已发生变化。

简介#

JAX 正在被用于各种各样的工作负载,从大型机器学习到科学计算。JAX 的成功故事也是 XLA 的成功故事,XLA 是 JAX 的主要编译器,它为加速器编译 JAX 程序,并使 JAX 能够扩展到最大的 ML 模型。JAX 在 XLA 的表示 HLO 中描述逻辑计算。HLO 描述了计算在逻辑上的执行方式,而不是物理上的。给定一个逻辑 HLO 计算,XLA 决定如何物理地执行该计算。对于各种各样的 ML 应用,XLA 在编译用户程序方面做得很好,但不可避免地一些用户会遇到 XLA 的局限性。在这些情况下,我们需要提供一个“逃生通道”,允许专家编写手动调整的内核,在那个时间点上超过 XLA 的性能。此外,ML 系统研究的进步需要一段时间才能融入 XLA,用户通常希望领先于他们。随着时间的推移,编译器可以合并通过手动调整的内核在实验中验证的优化。

XLA 确实提供了 CustomCall 机制作为逃生通道,但它要求用户编写 C++,并且在 GPU 上它要求用户学习 CUDA 编程模型。对于许多机器学习 GPU 内核(如矩阵乘法)来说,CUDA 编程模型可以说是太底层了,即使是专家用户也会难以使用 CUDA 实现高效的矩阵乘法或多头注意力。不仅如此,JAX 用户通常熟悉 Python 和 NumPy 风格的数组编程,它不涉及编写任何 C++ 代码或考虑 GPU 并行性。所有流行的机器学习框架都共享这个想法:用高级操作(通常)操纵数组,例如 matmulconvolution。不幸的是,这意味着通过 CustomCall 实现自定义操作是一项巨大的投资,可能需要学习 C++ 和/或 GPU 编程。

Triton 是一个由 OpenAI 构建和维护的 GPU 编译器,它席卷了 ML 编译器世界。Triton 同时提供了最好的两方面:用于 GPU 内核的基于数组的编程模型。Triton 是 PyTorch 2.0 中 torch.compile 的主要代码生成路线,通过 Torch Inductor 库实现。Triton 积极隐藏 GPU 编程的某些方面,以实现更易于访问的编程模型,该模型可以从 Python 使用,并从更高级别的表示生成优化代码。虽然 GPU 比 Triton 提供的更灵活,但在 ML 领域,Triton 似乎对许多应用来说已经足够表达了。

在这篇文档中,我们描述了 Pallas,它是 JAX 的扩展,它允许使用类似 Triton 的模型为 GPU 和 TPU 编写内核编程。基于 JAX 的内核语言提供了一些优势

  • 虽然 Triton 向用户公开了一个类似 TPU 的编程模型,即为 L1 缓存中的数组块编写程序,但它足够专门化到 GPU,以至于我们不能直接为 TPU 编译 Triton。例如,Triton 提供了专门用于处理不一定要在 TPU 上有意义的并行写入的原子操作。更高级别的前端可以抽象出平台的细节,同时只显示基于块的编程模型。因此,内核将在不同的硬件平台上可移植。

  • JAX 作为数值计算的基于跟踪的前端既成熟又使用广泛。通过将内核编程语言嵌入到 JAX 本身,我们可以重复使用 JAX 的跟踪基础设施,并提供一个用户已经熟悉的类似 NumPy 的前端。

  • JAX 变换是其成功的关键,它允许用户表达简单的程序,但将其变换以实现复杂的功能。我们可以利用相同的变换(vmap、jvp 等)来变换用户编写的内核。

现在的问题是:JAX 是否适合内核语言?我们认为是。Triton 证明了一种数组编程语言可以用于编写 GPU 内核,而 JAX 就是这样。JAX 还被证明是编译器和程序变换的灵活前端。

我们对 Pallas 的描述如下:首先,我们描述了我们扩展 JAX 以支持编写自定义内核的方式。然后,我们展示了如何将 Pallas 降低到 Triton 和 Mosaic。最后,我们描述了通过 JAX 变换变换 Pallas 内核的现有和潜在方式。

Pallas lowering path Pallas 降低路径可视化

Pallas:扩展 JAX 用于内核#

我们想说明的关键点是,Pallas 仅仅是 JAX,带有一些扩展

  1. 用户现在在他们的 JAX 代码中使用称为 Ref 的引用类型。这使得用户可以更精确地控制 JAX 中的内存访问和布局,更接近于物理布局。

  2. 用户使用 JAX 原语的子集以及一组 Pallas 特定的原语编写他们的 JAX 程序。

  3. 用户通过一个特殊的 pallas_call 高阶函数将其 Pallas 内核嵌入到外部 JAX 程序中,该函数在映射中执行内核。它类似于 pmapshard_map,只是引用了共享内存。

我们将通过示例逐个介绍这三个扩展。

请注意,这些 API 仍然处于实验阶段,可能会发生变化。

引用类型#

让我们看一个用于添加两个向量的 Pallas 程序示例

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  # In this code, `x_ref`, `y_ref` and `o_ref` are (8,)-shaped `Ref`s
  x = x_ref[:]
  y = y_ref[:]
  o_ref[:] = x + y
x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32))
add(x, y)

与常规 JAX 程序不同,add_kernel 不接收不可变的数组参数。相反,它提供了可以从其读取并在其上就地更新的引用,使用类似 NumPy 的语法。 Ref 不是 Pallas 特定的概念,它们是引入到 JAX 中来表示有状态计算的。但是,当编写操作可变内存的内核时,我们可以利用它们。

Pallas 内核不仅接收与内核输入相对应的 Ref,还接收输出的 Ref(在 pallas_call 中通过 out_shape 指定)。 Ref 是特殊类型,不能在不首先读取的情况下传递到通常的 JAX 原语集中。当从 Ref 读取时,您将得到一个 JAX Array 类型,并且您必须将 Array 写入 Ref

从 Ref 读取/写入#

Ref 读取对应于将数组加载到内存层次结构的最低级别(GPU 上的 L1 缓存和 TPU 上的向量寄存器)。写入 Ref 类似。

def f(x_ref, o_ref):
  # Using vanilla Python indexing
  x = x_ref[0, 2:5, :]
  # Or via Numpy advanced int indexing
  o_ref[jnp.arange(3), :] = x

# Note that in order to use NumPy advanced int indexing, you need to broadcast the indices against each other into the desired multidimensional shape:
def f(x_ref):
  # Assume x_ref is (8, 4) and we want to read out a (2, 3) slice
  x = x_ref[jnp.arange(2)[..., None], jnp.arange(3)[None, ...]]

可以通过类似的 __setitem__ 样式索引写入 Ref

其他形式的索引(例如,动态切片)可以通过 pallas.loadpallas.store 完成,这些是新设计的 JAX 原语,旨在使从内存加载/存储到内存更容易。我们将在后面讨论这些新的原语。

使用新的 Pallas 原语扩展 JAX#

由于 JAX 是为 HLO 设计的,因此 JAX 原语集与 HLO 操作集非常相似。针对新的编译器(例如 Triton 或 Mosaic)意味着我们可能需要用特定于新编译器的新的原语来补充 JAX 的原语。同时,我们可能无法降低所有 JAX 原语,因此我们需要将其限制为子集。

由于 Pallas 最初是针对 Triton 设计的,因此我们提供了一组针对 Triton 编程模型的新原语。正如我们将在后面展示的那样,我们也可以将这些原语降低到 Mosaic。

pallas.loadpallas.store#

pallas.loadpallas.store 是允许从内存加载和存储到内存的原语。与 __getitem____setitem__ 不同,它们更灵活,但代价是更冗长。具体来说,您可以使用 pallas.dynamic_slice(简称 pallas.ds)结构(可能应该上游到 JAX,以便与 Ref __getitem____setitem__ 一起使用)。

def f(x_ref, o_ref):
  # Reading from memory via pallas.load
  x = pl.load(x_ref, (0, slice(2, 5), slice(None)))
  # Using integer indexing automatically broadcasts
  x = pl.load(x_ref, (0, 2 + jnp.arange(3), slice(None)))
  # You can also use `pl.dynamic_slice` (`pl.ds` for short) objects as well
  pl.store(o_ref, (0, pl.ds(start=2, size=3), slice(None)), x)

pallas.loadpallas.store 也支持通过 mask 参数进行掩码。

def f(x_ref, o_ref):
  # Reading from memory via pallas.load
  idx = jnp.arange(8)
  mask = idx < 5
  x = pl.load(x_ref, (idx,), mask=mask, other=float('-inf'))

在进行越界加载/存储时,掩码非常重要。掩码的操作语义可以由编译器确定(如果我们正确理解文档,Triton 会避免从掩码内存进行读取/写入)。

pallas.program_idpallas.num_programs#

正如我们将很快看到的那样,我们将多次执行相同的 Pallas 内核(并行或在管道中,具体取决于后端)。这些新的原语告诉我们内核执行的“位置”。

pallas.program_id 接受一个 axis 参数,它告诉我们内核当前在多维网格的哪个轴上执行(类似于 CUDA 编程中的 threadIdjax.pmap 中的 lax.axis_index)。请注意,我们目前借用了 Triton 的“程序”术语,将来我们可能希望将其更改为对 JAX 用户来说更熟悉的术语。

def f(x_ref, o_ref):
  i = pl.program_id(axis=0)  # execution index in the first axis of the grid
  o_ref[i] = jnp.exp(x_ref[i])

pallas.num_programs 也接受一个 axis,并返回该轴的网格大小。

请注意,虽然 program_idnum_programs 是 Triton 特定的术语,但它们很容易推广到在 TPU 上也有意义。

在 Pallas 中使用 JAX 原语的子集#

由于我们正在编写内核,而不是高级 HLO 程序,因此一些 JAX 原语可能无法在底层基板上有效地表示。但是,我们知道我们可以支持大多数逐元素操作、简单的点积和 JAX 控制流。

虽然我们还没有完全确定我们可以在 Pallas 内核中支持的所有 JAX 原语,但我们肯定可以识别出一些难以降低或不太可能有用。

  • conv_general - 卷积通常不出现在底层硬件中作为原语。

  • gather/scatter - 底层编译器可能不支持非连续内存读取和写入

使用 pallas_call 执行 Pallas 内核#

现在我们已经编写了 Pallas 内核(也称为使用 Ref 的 JAX 以及额外的 Pallas 原语),如何将它们在 GPU 或 TPU 上执行呢?我们使用 pallas_call,这是一个高阶函数(类似于 jax.jitjax.pmap),用于执行内核。

pallas_call 的签名如下

def pallas_call(
    kernel: Callable,
    out_shape: Sequence[jax.ShapeDtypeStruct],
    *,
    in_specs: Sequence[Spec],
    out_specs: Sequence[Spec],
    grid: Optional[Tuple[int, ...]] = None) -> Callable:
  ...

当我们向 pallas_call 提供内核时,我们还提供了其他信息。第一个是 out_shape,它告诉内核输出是什么样的(pallas_call 将传递一个对应于这些的 Ref 到内核中,以便写入)。其余信息(in_specsout_specsgrid)是关于内核如何在加速器上调度的信息。

pallas_call 的(大致)语义如下

def pallas_call(kernel, out_shape, *, in_specs, out_specs, grid):
  def execute(*args):
    outputs = map(empty_ref, out_shape)
    grid_indices = map(range, grid)
    for indices in itertools.product(*grid_indices): # Could run in parallel!
      local_inputs = [in_spec.transform(arg, indices) for arg, in_spec in
                      zip(args, in_specs)]
      local_outputs = [out_spec.transform(arg, indices) for arg, out_spec  in
                       zip(outputs, out_specs)]
      kernel(*local_inputs, *local_outputs) # writes to outputs
  return execute

具体来说,pallas_call 将在网格迭代空间中“循环”,应用一个变换,该变换通过 in_specsout_specs 指定对输入和输出进行变换。在每次迭代中,内核将在变换后的输入和输出上调用。请注意,“循环”在迭代空间上的执行可能是并行的(例如,在 GPU 上)。pallas_call 也不保证迭代空间中循环迭代的顺序,只是保证迭代空间的每个成员都会被循环。像 Triton 和 Mosaic 这样的编译器将在网格中具有更具体的运行语义。

变换函数#

pallas_callin_specsout_specs 参数允许以某种方式变换输入和输出。Pallas 目前提供的两个选项是身份变换(输入和输出保持不变),以及 BlockSpec,它获取由循环索引确定的 Ref 的固定大小的切片。

BlockSpec 获取一个 index_map 函数和一个 block_shape。逻辑上,它获取一个数组,并沿着每个轴将其切片成 block_shape 大小的块。 index_map 函数获取循环索引(来自网格索引集),并将它们映射到块索引。变换函数将 Ref 转换为在相应块处的 Ref 的逻辑视图。当我们在 block_shape 中的条目中指定 None 时,它对应于“映射”该维度,将其从内核内的块中移除。

class BlockSpec:
  index_map: Callable[[Tuple[Int, ...]], Tuple[Int, ...]]
  block_shape: Tuple[Optional[int], ...]

  def transform(self, ref, *loop_indices):
    block_indices = self.transform_function(loop_indices)
    # Returns a view of `ref` starting at `block_indices` of shape self.block_shape
    ...

我们还可以想象其他与 pallas_call 一起使用的 Spec,例如一个 Spec,它对应于重叠窗口,例如,用于实现卷积。

Pallas 作为前端的直接好处#

通过为内核编写提供一个 JAX 前端,我们可以立即获得一些好处。

更灵活的前端#

第一个好处是,JAX 用户已经习惯了使用 JAX 及其基于跟踪的变换带来的好处(和局限性)。这意味着用户可以在编写 Pallas 内核时使用闭包和其他熟悉的 Python 结构。这与现有的基于 AST 解析的 Triton 前端或 Mosaic 的 MLIR 构建器不同。例如,这使得 Pallas 比 Triton 更适合模板化。

请查看以下示例,说明我们如何在 Python 中使用高阶函数来模板化一个内核。

def make_kernel(eltwise_kernel):
  def add(x_ref, y_ref, o_ref):
    x = pl.load(x_ref, ())
    y = pl.load(y_ref, ())
    pl.store(o_ref, (), eltwise_kernel(x + y))
  return add

kernel1 = make_kernel(lambda x: x * 2)
kernel2 = make_kernel(jnp.exp)

pl.pallas_call(kernel1, out_shape=x, grid=1)(1., 1.)
pl.pallas_call(kernel2, out_shape=x, grid=1)(1., 1.)

仿真模式#

通过将内核表示为带有 JAX 原语和一些新的 Pallas 原语的程序,我们也可以将 Pallas 程序直接降低到 StableHLO,并使用 XLA 进行编译和执行。具体来说,一个 pallas_call 可以实现为在网格上的 lax.scan。这使我们能够在任何支持 XLA 的平台(甚至 CPU!)上开发 GPU 或 TPU 内核,并使用 JAX/XLA 调试工具(例如 jax.debug.print)进行调试。我们还可以使用更可靠和经过更好测试的 XLA 数字来验证 Triton 和 Mosaic 编译器的正确性。还可以想象通过扰乱 scan 的排序来模拟在 GPU 上发生的并行读写。

示例#

add#

我们修改了 add_kernel 示例,使用 BlockSpec 在 (2,)-sized 块上操作。

def add_kernel(x_ref, y_ref, o_ref):
  # In this code, `x_ref`, `y_ref` and `o_ref` are (2,)-shaped `Ref`s
  x = x_ref[:]
  y = y_ref[:]
  o_ref[:] = x + y
x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
    in_specs=[
        pl.BlockSpec((2,), lambda i: i),
        pl.BlockSpec((2,), lambda i: i)
    ],
    out_specs=pl.BlockSpec((2,), lambda i: i),
    grid=(4,))
add(x, y)

模板化矩阵乘法#

在这个例子中,我们通过对输入数组中的行和列块进行展开累积来计算输出的块。我们使用高阶函数将激活函数内联到内核主体中,以便我们可以发出融合内核。

def matmul_kernel(x_ref, y_ref, o_ref, *, activation, block_k):
  acc = jnp.zeros((x_ref.shape[0], y_ref.shape[1]), jnp.float32)
  for k in range(x_ref.shape[1] // block_k):
    x = x_ref[:, k*block_k:(k+1)*block_k]
    y = y_ref[k*block_k:(k+1)*block_k, :]
    acc += x @ y
  o_ref[:, :] = activation(acc).astype(o_ref.dtype)

x, y = jnp.ones((512, 256)), jnp.ones((256, 1024))
block_shape = 128, 256, 128

@partial(jax.jit, static_argnames=["block_shape", "activation"])
def matmul(x, y, *, block_shape, activation):
  block_m, block_n, block_k = block_shape
  fused_matmul = pl.pallas_call(
      partial(matmul_kernel, block_k=block_k, activation=activation),
      out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1],), jnp.float32),
      in_specs=[
          pl.BlockSpec((block_m, x.shape[1]), lambda i, j: (i, 0)),
          pl.BlockSpec((y.shape[0], block_n), lambda i, j: (0, j))
      ],
      out_specs=pl.BlockSpec((block_m, block_n), lambda i, j: (i, j)),
      grid=(4, 4),
  )
  return fused_matmul(x, y)

z = matmul(x, y, block_shape=block_shape, activation=jax.nn.gelu)

降低 Pallas#

在用户表达完 Pallas 内核后,我们将它们降低到不同的表示,具体取决于目标后端。在 GPU 上,我们将 Pallas 降低到 Triton IR,在 TPU 上,我们将 Pallas 降低到 Mosaic。

将 Pallas 降低到 Triton 以用于 GPU#

将 Pallas 降低到 Triton 很容易,因为 Pallas 的设计考虑了 Triton 作为目标语言。Pallas 和 Triton 之间的主要区别在于 Triton 没有 BlockSpec 的概念,并且在进行内存加载和存储时使用指针而不是索引。

Triton 在其语言中支持指针作为数组元素类型,并且在 Triton 中,您可以从指针数组中加载和存储。在 Pallas 中,当给定一个形状为 (4, 5)Refx_ref,然后执行类似 x_ref[3, 2] 的操作时,我们需要将其降低到计算 x_ref 中相应行主位置的 Triton 指针(即,执行 5 * 3 + 2 * 1)。类似地,当我们将切片降低到 Triton 时,例如 x_ref[4, :],我们需要生成一个指针数组 5 * 4 + jnp.arange(3)

除此之外,降低到 Triton 相当简单。JAX 点积可以降低到 Triton 点积,JAX 一元原语可以降低到其 Triton 等效项。Triton 的原子操作通过新的 Pallas 原子原语降低。

将 Pallas 降低到 Mosaic 以用于 TPU#

Mosaic 使用(大部分)标准方言 MLIR,并发出 LLO 以便为 TPU 编译。Pallas 可以通过将 JAX 原语转换为 MLIR(主要是 vectorarith 方言)来降低到 Mosaic。 BlockSpec 可以转换为管道调度(即 Mosaic 中的 transform_func)。

变换 Pallas#

一个自然的问题是,JAX 变换如何与 Pallas 内核交互?主要有两种方式:Pallas 内核内部的变换和 Pallas 内核外部的变换。

Pallas 内核内部的变换实际上应该“正常工作”,只要我们能够降低变换后的代码。例如,我们可以在 JAX 内核内使用 jax.grad(jnp.sin)(...),因为我们可以将 cos 降低到 Triton 和 Mosaic。但是,我们可能无法降低 jax.vmap(lax.dynamic_slice),因为它可能会变成我们无法降低的收集。

来自外部 JAX 程序的 Pallas 内核的变换也许更有趣。我们如何处理诸如 vmap(pallas_call)grad(pallas_call) 之类的事情呢?

vmap-of-pallas_call#

vmap 自动向量化 JAX 程序。虽然内核编写者可能希望对批处理内核的行为与未批处理内核的不同方式进行精确控制,但我们可以为 pallas_call 提供一个合理的默认 vmap 规则,同时提供 jax.custom_vmap 自定义机制。当 pallas_callvmap 时,我们将为 pallas_call 添加一个额外的网格维度,对应于新的批处理维度,并变换 BlockSpec 以处理沿着该维度的索引。

grad-of-pallas_call#

grad of pallas_call 使内核能够自动微分。jax.grad 分解为三个不同变换的应用:jvppartial_evaltranspose。原则上,在为 pallas_call 实现这些规则时,我们可以重新使用 JAX 的大部分基础设施(因为它在行为上与现有的 JAX 高阶原语非常相似)。

然而,内核的自动微分可能会导致性能下降,因为内存访问的方式被转置了。如果我们编写一个具有重叠并行读取和不相交并行写入的 GPU 内核,我们将自动将其转置为一个内核,该内核具有重叠并行写入(在原子执行时速度较慢)和不相交并行读取。为了发出一个更好地利用共享内存并行性的内核,我们需要重新排序循环并改变内核的矢量化方式。不幸的是,我们在 Pallas 中没有一个适合这种操作的程序表示。高效自动微分内核的一个潜在方向是探索不同的表示,也许类似于 Dex 中的表示。我们也可以看看 Enzyme 是如何解决这个问题的。然而,Pallas 内核的 AD 对于那些能有效地转置的内核类别仍然有用(例如元素级内核)。

总的来说,jax.custom_vjp 是一个可行的避风港,用于表达与 jax.grad 一起工作的 Pallas 内核。

其他转换#

我们可以想象其他 JAX 转换应用于 Pallas 内核,而我们还没有明确探索。例如,checkify 是一个执行函数式错误处理的 JAX 转换。我们可以想象使用 checkify 与 pallas_call 一起使用,允许从 GPU 内核中排出错误代码,这些代码指示是否产生了 OOB 访问或 NaN。

另一个可以与之集成的潜在转换是 custom_partitioning,以使可自动分区内核能够与 pjit 一起使用。