矩阵乘法#

在本指南中,我们将使用 Pallas 编写一个矩阵乘法例程。我们还将讨论如何考虑 TPU 上的 matmul 性能以及如何模板化 matmul 内核以融合操作。

#@title Imports
import functools
from typing import Callable

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax import random
import jax.numpy as jnp
import numpy as np

背景#

矩阵乘法是现代深度学习和语言建模核心的基本线性代数运算。我们希望使用像 TPU 和 GPU 这样的专用加速器尽可能快地进行 matmul,它们都有用于快速矩阵乘法的专用单元。

为了有效地利用 TPU 进行矩阵乘法,我们需要涵盖一些背景概念:块矩阵乘法、平铺和流水线。

块矩阵乘法#

假设我们想要实现 matmul(x, y),它通常将一个 (m, k) 数组与一个 (k, n) 数组相乘,但有一个小小的变化。我们只允许使用原始的 matmul_small,它将小矩阵相乘(比如 m, k, n <= 256)。我们该如何实现呢?

矩阵乘法的一个很好的特性是,输出的每个块都可以表示为输入行块和列块的几个较小的矩阵乘法的和。形式上,如果我们有输入数组 \(x \in \mathbb{R}^{m \times k}\)\(y \in \mathbb{R}^{k \times n}\) 以及输出 \(z \in \mathbb{R}^{m \times n}\),我们将它们沿着大小为 \(b_m, b_k, b_n\) 的维度分解为块。

例如,\(x\) 将被分解为

\[\begin{split} \begin{bmatrix} x_{0, 0} & \cdots & x_{0, i_k} \\ x_{1, 0} & \cdots & x_{1, i_k} \\ \vdots & \ddots & \vdots \\ x_{i_m, 0} & \cdots & x_{i_m, i_k} \\ \end{bmatrix} \end{split}\]

其中 \(x_{ik} \in \mathbb{R}^{b_m \times b_k}\)。(我们可以类似地分解 \(y\)\(z\)。)

对于特定的输出块 \(z_{ij}\),我们可以将其计算为

\[ z_{ij} = \sum_k x_{ik} y_{kj} \]

因此,每个输出块 \(z_{ij}\) 是几个较小的块矩阵乘法 \(x_{ik} y_{kj}\) 的和。以下是在 NumPy 中实现此算法的方法

def matmul_small(x: np.ndarray, y: np.ndarray) -> np.ndarray:
  m, k, n = x.shape[0], x.shape[1], y.shape[0]
  assert m <= 256
  assert k <= 256
  assert n <= 256
  return np.matmul(x, y)

def block_matmul(
    x: np.ndarray,
    y: np.ndarray,
    *,
    bm: int = 256,
    bk: int = 256,
    bn: int = 256,
) -> np.ndarray:
  m, k = x.shape
  _, n = y.shape

  z = np.zeros((m, n), dtype=x.dtype)
  for m_i in range(m // bm):
    for n_i in range(n // bn):
      for k_i in range(k // bk):
        m_slice = slice(m_i * bm, (m_i + 1) * bm)
        k_slice = slice(k_i * bk, (k_i + 1) * bk)
        n_slice = slice(n_i * bn, (n_i + 1) * bn)
        x_block = x[m_slice, k_slice]
        y_block = y[k_slice, n_slice]
        z[m_slice, n_slice] += matmul_small(x_block, y_block)
  return z

我们的 block_matmul 函数现在应该可以处理大于 256 的输入(尽管我们假设我们的输入维度可以被 256 整除)。

m, k, n = 4096, 4096, 4096
x = np.random.uniform(size=(m, k)).astype(np.float32)
y = np.random.uniform(size=(k, n)).astype(np.float32)
np.testing.assert_allclose(x @ y, block_matmul(x, y), atol=1e-6, rtol=1e-6)

block_matmul 通过观察到每个大小为 (bm, bn) 的输出块可以通过累积几个 (bm, bk) x (bk, bn) 大小的矩阵乘法来计算,从而将矩阵乘法分解为许多较小的矩阵乘法。

TPU 和 GPU 就是这样进行矩阵乘法的!它们原生支持类似于 matmul_small 的小矩阵乘法,因此为了在进行更大的矩阵乘法时利用此硬件,我们将应用 block_matmul 分解。

分块和流水线#

之前的指南中,我们介绍了如何在 Pallas 中进行计算分块和流水线操作。为了确保我们的计算单元始终在工作并且永远不会因内存传输而停滞,我们将内核的下一次迭代的内存传输与当前的迭代重叠。

在 Pallas 中,我们通过 BlockSpec 和一个 grid 来指定这一点。请注意,我们在块矩阵乘法算法中已经有一个嵌套的 for 循环。我们可以通过 Pallas 中的 grid 来指定它。块矩阵乘法中的切片也可以通过 BlockSpec 来指定。

你的第一个矩阵乘法内核#

总而言之,这是一个块矩阵乘法内核的实现,该内核将内存传输与计算进行流水线操作。我们创建一个 3 维网格,对应于 NumPy 代码中的 3 个嵌套循环。请注意,虽然 MXU 只能进行小块的乘法运算,但 Pallas 会自动获取更大的块,并自动将它们分块到 MXU 上。

网格的最后一个维度对应于矩阵乘法的收缩维度,并且是一个归约维度,因此我们需要确保初始化累加器。

def matmul_kernel(x_ref, y_ref, z_ref):
  @pl.when(pl.program_id(2) == 0)
  def _():
    z_ref[...] = jnp.zeros_like(z_ref)

  z_ref[...] += x_ref[...] @ y_ref[...]

def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
):
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      matmul_kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      in_specs=[pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
                pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],
      out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
      grid=(m // bm, n // bn, k // bk),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
m, k, n = 4096, 4096, 4096
k1, k2 = random.split(random.key(0), 2)
x = random.normal(k1, (m, k), dtype=jnp.float32)
y = random.normal(k2, (k, n), dtype=jnp.float32)
np.testing.assert_array_equal(x @ y, matmul(x, y))

矩阵乘法性能#

让我们思考一下如何分析矩阵乘法的性能。当我们考虑矩阵乘法性能时,我们通常会关心两件事:浮点运算总数 (FLOP) 和内存带宽的使用量。从关于 TPU 和流水线的指南中,我们看到为了使用 TPU(和通用 ML 加速器)上的高效计算单元,我们需要将输入从 HBM 复制到 VMEM,更接近计算单元。往返于 HBM 的这种复制需要时间,并且高效的内核希望大部分时间都花费在实际计算上,而不是等待这些传输。内存带宽衡量此数据传输的速率。

快速说明:在本指南中,我们将讨论浮点运算,但需要区分 FLOP 与 FLOP/s。当我们说“FLOP”时,我们指的是“浮点运算”,即运算的数量。当我们说“FLOP/s”时,我们指的是“每秒浮点运算”,即执行浮点运算的速率

(m, k) x (k, n) 矩阵乘法中的 FLOP 数量(近似)为 2 * m * k * n。(技术上是 n * m * (2k - 1),但对于足够大的 k,我们的近似值就足够了。)

矩阵乘法的最小内存带宽使用量(假设为 float32)是输入总大小(复制到 VMEM)加上输出大小(复制到 HBM)。因此,最小带宽使用量为 (m * k + k * n + m * n) * 4 字节/float32。如果我们多次重新读取输入,内存使用量可能会更大,这种情况经常发生。

一个观察结果是,matmul FLOP 的数量是其输入的三次方,而最小带宽使用量是其输入的二次方。直观地说,这意味着 FLOP 的增长速度比带宽使用量快,这意味着我们的 matmul 越大,相对于复制,我们拥有的计算就越多。

def matmul_flops(m: int, k: int, n: int):
  return 2 * m * k * n

def matmul_membw(m: int, k: int, n: int, dtype: jnp.dtype):
  return (m * k + k * n + m * n) * np.dtype(dtype).itemsize

print(matmul_flops(1024, 1024, 1024))
print(matmul_membw(1024, 1024, 1024, jnp.float32))
2147483648
12582912

既然我们可以计算矩阵乘法的 FLOP 总数和(最小)内存带宽使用量,让我们看看真正的 TPU 可以处理什么。

此笔记本是在 TPU v5e 芯片上运行的,因此我们将使用 v5e 的数字(如果您正在运行此笔记本,您的数字可能会有所不同)。TPU v5e 具有197 TFLOP/s 的 bf16/f32 计算能力和 819 GB/s 的内存带宽。通过查看这些数字的比率(称为算术强度),我们可以得到一个界限,即“FLOP/内存带宽使用量”的比率在达到 IO 限制之前可以降低多少(在 TPU v5e 上约为 240 FLOP/字节)。

v5e_flops = 197e12
v5e_membw = 819e9
v5e_op_intensity = v5e_flops / v5e_membw  # ~240.5

粗略地说,这些数字告诉我们,matmul 的 FLOP 应该花费 2 * m * k * n / (197 TFLOP/s) 秒,而往返于 VMEM 的复制应该花费 (m*k + k*n + m*n) * 4 字节 / 819GB/s 秒。

def matmul_flops_intensity(m: int, k: int, n: int, dtype: jnp.dtype):
  flops = matmul_flops(m, k, n)
  membw = matmul_membw(m, k, n, dtype)
  return flops / membw

这个基本计算告诉我们,我们将能够以多高的效率使用我们的 MXU。如果我们的 matmul 操作强度低于我们的芯片所能承受的,那么我们的计算将受到内存限制,即我们的计算单元将在等待值传输时处于空闲状态。如果 matmul 强度高于芯片所能承受的,那么我们将受到计算限制

由于 matmul FLOP 的数量是其输入大小的三次方,而内存带宽使用量是其输入的二次方,因此我们预计随着我们变得越来越大,我们将受到计算限制,但这个交叉点非常重要!假设我们正在执行 (1024, 1024) x (1024, 1024) float32 矩阵乘法。

print(f"{matmul_flops_intensity(1024, 1024, 1024, jnp.float32)} flops/byte")
170.66666666666666 flops/byte

我们的矩阵乘法浮点运算强度低于芯片的性能。这不太好!这种类型的矩阵乘法很可能受内存限制。然而,如果我们的输入和输出更大呢?当我们的矩阵乘法足够大时,我们将在某个时刻从内存受限转变为计算受限。例如,如果我们有一个矩阵乘法,其中 m = k = n,那么当 2m**3 / 12m**2 > 240 或当 m = k = n > 1440 时,我们将在 (TPU v5e 上) 跨越这个界限。

bfloat16 矩阵乘法#

为了使矩阵乘法更容易在 TPU 上计算受限,我们还可以为输入和输出使用较小的数据类型。我们之前的示例使用了 float32 输入和输出,但 TPU v5e 也支持 bfloat16 数据类型(一种 16 位浮点格式,也称为 bf16)用于矩阵乘法。在 TPU v5e 上,我们将具有相同的 FLOP/s,但将减半我们的内存带宽使用量。这使得较小矩阵更容易计算受限。让我们看看一个 1024 x 1024 x 1024 bf16 矩阵乘法的强度是多少

print(f"{matmul_flops_intensity(1024, 1024, 1024, jnp.bfloat16)} flops/byte")
341.3333333333333 flops/byte

我们现在有一个计算受限的矩阵乘法!

让我们为我们的矩阵乘法内核添加 bf16 支持。

原生 MXU bf16 矩阵乘法例程接受两个输入 bf16 矩阵,并在 f32 中累积它们。我们将通过将 preferred_element_type=jnp.float32 传递到 jnp.matmul 中来触发此例程。我们还需要一个 f32 中的累加器 Ref。然后,在将其写回 HBM 之前,我们将把输出向下转换为 bf16。这样,我们不会损失任何精度,不会进行任何额外的强制转换,并且仍然保留 bf16 内存带宽节省。

请注意,现在分配暂存空间的唯一方法是通过 pltpu.PrefetchScalarGridSpec。现在不要担心它具体做什么 – 您现在需要知道的是,它允许您在 VMEM 中分配暂存空间。

def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps):
  @pl.when(pl.program_id(2) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  acc_ref[...] += jnp.dot(
      x_ref[...], y_ref[...], preferred_element_type=jnp.float32
  )

  @pl.when(pl.program_id(2) == nsteps - 1)
  def _():
    z_ref[...] = acc_ref[...].astype(z_ref.dtype)


@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn'])
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
):
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      functools.partial(matmul_kernel, nsteps=k // bk),
      grid_spec=pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[
            pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
            pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)),
        ],
        out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
        scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
        grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
m, k, n = 4096, 4096, 4096
k1, k2 = random.split(random.key(0), 2)
x = random.normal(k1, (m, k), dtype=jnp.bfloat16)
y = random.normal(k2, (k, n), dtype=jnp.bfloat16)
np.testing.assert_array_equal(x @ y, matmul(x, y))

流水线内核的性能#

我们上面关于 FLOP 与内存使用情况的分析是在粗略的尺度上应用的,即当我们查看整个矩阵乘法的大小时。但是,请记住,在实践中,我们正在流水线执行分块矩阵乘法,这意味着我们有一个循环,在其中使用较小的块进行矩阵乘法。

这意味着我们实际上关心的是内核的每个单独实例的 FLOP 与内存带宽使用情况,而不是全局 FLOP 与内存带宽使用情况。因此,块大小 bmbkbn 对于性能至关重要。即使我们拥有世界上最大的矩阵,如果我们选择非常小的 bmbkbn,我们将受到内存限制,因为每次调用内核时,我们的 FLOP 太少,无法隐藏在后台发生的内存传输。

因此,直觉应该是:要计算受限,请尽可能地使块更大!有两个主要约束

  1. VMEM 使用:我们的块越大,我们使用的 VMEM 就越多。使用足够大的块,我们将耗尽。

  2. 流水线气泡:我们的块相对于矩阵大小越大,我们流水线中的循环迭代次数就越少。这将使流水线开始和结束时的气泡大小相对于整个流水线更大,并且这种开销可能很重要。

在 Pallas 中获得良好的矩阵乘法性能归结为选择好的块大小来平衡这个优化问题。在实践中,我们经常扫描大量候选块大小,分析内核,然后选择最佳的一个。

现在,让我们做一些非常简单的计时实验。我们将使用 timeit 来测量运行每个内核所花费的时间。请注意,这只是内核实际运行时间的上限,因为我们正在使用 timeit 测量 Python 调度和其他开销。我们将计算以这种方式获得的 FLOP/s,并计算与芯片提供的性能相比我们获得的利用率百分比,我们将使用一些合理的块大小来验证我们的直觉。

import timeit

def benchmark(f, ntrials: int = 100):
  def run(*args, **kwargs):
    # Compile function first
    jax.block_until_ready(f(*args, **kwargs))
    # Time function
    result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
                           number=ntrials)
    time = result / ntrials
    # print(f"Time: {time}")
    return time
  return run

def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
                   mm_func):
  x = jnp.ones((m, k), dtype=dtype)
  y = jnp.ones((k, n), dtype=dtype)
  time = benchmark(mm_func)(x, y)
  print(f"----- {m} x {k} x {n} -----")
  print("Matmul time: ", time)
  mm_flops = matmul_flops(m, k, n) / time
  print("Matmul FLOP/s: ", mm_flops)
  print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
  print()

print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)

print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00029766598949208854
Matmul FLOP/s:  7214407167121.377
FLOP/s utilization: 3.6621%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.011771515250438824
Matmul FLOP/s:  11675553278230.387
FLOP/s utilization: 5.9267%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.09183577066054567
Matmul FLOP/s:  11972585626140.668
FLOP/s utilization: 6.0775%

================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00012708659982308746
Matmul FLOP/s:  16897797651282.135
FLOP/s utilization: 8.5776%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.00088908776990138
Matmul FLOP/s:  154584235803001.88
FLOP/s utilization: 78.4692%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006099433819763363
Matmul FLOP/s:  180264539343531.62
FLOP/s utilization: 91.5048%

更大的块大小有很大帮助!我们在更大的矩阵乘法中获得了相当好的利用率(80-90%),但最小的矩阵乘法似乎很难获得良好的性能。

让我们将其与 XLA 的矩阵乘法进行比较。我们不期望 Pallas 比 XLA 做得更好,因为 XLA 在生成矩阵乘法方面非常出色,但希望我们能接近。通过更仔细的块大小调整(留作未来工作),我们也可以达到 XLA 的性能。

print("================ XLA matmul ===================")
mm = jnp.matmul
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
================ XLA matmul ===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00011943008983507753
Matmul FLOP/s:  17981093801113.996
FLOP/s utilization: 9.1275%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.0008272899803705514
Matmul FLOP/s:  166131533963991.34
FLOP/s utilization: 84.3307%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006047147869830951
Matmul FLOP/s:  181823175395037.44
FLOP/s utilization: 92.2960%

Pallas 通过一些非常基本的调整,非常接近 XLA 的性能数字!通过尝试更多的块大小,我们应该期望完全缩小差距。

模板化矩阵乘法#

现在我们有了一个基本的矩阵乘法内核,我们现在可以尝试将操作融合到其中。

融合的右侧转置#

通常要做的第一件事是融合转置。那是什么意思?假设我们想计算 x @ y.T 而不是 x @ y。天真地,我们可以先计算 y.T,然后将其传递到我们高效的矩阵乘法内核中。然而,操作 y.T 本身并非免费的 – 它涉及复制 O(n^2) 数据。理想情况下,我们可以在执行矩阵乘法仅在一个内核中计算转置,即将其与矩阵乘法“融合”。

加速器通常支持融合 RHS 转置的原生矩阵乘法例程。例如,TPU v5e,MXU 允许我们对小数组执行 x @ y.T。我们可以使用 jax.lax.dot_general 调用此例程,这将比单独进行转置然后进行矩阵乘法更有效。

def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps, transpose_rhs):
  @pl.when(pl.program_id(2) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  # dot_general expects a data structure (contraction_dims, batch_dims),
  # where contraction_dims are the set of dimensions for LHS and RHS that will
  # be contracted (reduced) in the matmul; batch_dims, on the other hand, are
  # looped over. The remaining dimensions will be the input and output dimension
  # of the matmul.
  if transpose_rhs:
    dims = ((1,), (1,)), ((), ())
  else:
    dims = ((1,), (0,)), ((), ())

  acc_ref[...] += jax.lax.dot_general(
      x_ref[...], y_ref[...], dims, preferred_element_type=jnp.float32,
  )

  @pl.when(pl.program_id(2) == nsteps - 1)
  def _():
    z_ref[...] = acc_ref[...].astype(z_ref.dtype)


@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'transpose_rhs'])
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
    transpose_rhs: bool = False,
):
  if transpose_rhs:
    y = y.swapaxes(0, 1)
    y_block_spec = pl.BlockSpec((bn, bk), lambda i, j, k: (j, k))
  else:
    y_block_spec = pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      functools.partial(matmul_kernel, nsteps=k // bk, transpose_rhs=transpose_rhs),
      grid_spec=pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[
            pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
            y_block_spec,
        ],
        out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
        scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
        grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)

我们在 matmul 函数中进行转置(y = y.swapaxes(0, 1))。这是因为在 JIT 化的 JAX 计算内部,维度顺序纯粹是逻辑的,而不是物理的,因此重新排列维度并不意味着物理布局差异。但是,当我们将数组传递给 pallas_call 时,我们确实会强制执行主到次的维度排序约束。通过在 matmul 函数内部转置 y,我们要求 y 采用转置布局 (n, k) 而不是通常的 (k, n)。但是,用户仍然会以(逻辑)(k, n) 维度传递数组。

注意:为了对转置进行基准测试,我们实际上希望 y 在我们将其传递到内核时采用物理转置布局,这样我们就不会测量重新布局时间。在包装函数中,我们将在传递给 matmul 之前(逻辑上)将其转置回 (k, n),因为 matmul 期望逻辑 (k, n) 维度排序。

def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
                   mm_func, transpose_rhs: bool = False):
  x = jnp.ones((m, k), dtype=dtype)
  if transpose_rhs:
    y = jnp.ones((n, k), dtype=dtype)
    @jax.jit
    def _wrapper(x, y):
      y = y.swapaxes(0, 1)
      return mm_func(x, y, transpose_rhs=True)
  else:
    y = jnp.ones((k, n), dtype=dtype)
    _wrapper = mm_func
  time = benchmark(_wrapper)(x, y)
  print(f"----- {m} x {k} x {n} -----")
  print("Matmul time: ", time)
  mm_flops = matmul_flops(m, k, n) / time
  print("Matmul FLOP/s: ", mm_flops)
  print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
  print()

print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, transpose_rhs=True)

print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, transpose_rhs=True)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.0003029372810851783
Matmul FLOP/s:  7088872126624.065
FLOP/s utilization: 3.5984%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.012017967159627005
Matmul FLOP/s:  11436123235026.848
FLOP/s utilization: 5.8051%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.09500920018996112
Matmul FLOP/s:  11572685861765.383
FLOP/s utilization: 5.8745%

================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00012131539988331496
Matmul FLOP/s:  17701657415839.363
FLOP/s utilization: 8.9856%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.0008790623804088682
Matmul FLOP/s:  156347213275211.03
FLOP/s utilization: 79.3641%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006107717020204291
Matmul FLOP/s:  180020067095253.78
FLOP/s utilization: 91.3807%

看看我们如何获得相同的利用率,尽管有额外的转置!

融合的激活函数#

融合激活也很常见。这可以确保我们不会在高效的、计算受限的矩阵乘法内核之后执行缓慢的内存受限的激活内核。

def matmul_kernel(
    x_ref, y_ref, z_ref, acc_ref, *, nsteps, transpose_rhs, activation
):
  @pl.when(pl.program_id(2) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  if transpose_rhs:
    dims = ((1,), (1,)), ((), ())
  else:
    dims = ((1,), (0,)), ((), ())

  acc_ref[...] += jax.lax.dot_general(
      x_ref[...],
      y_ref[...],
      dims,
      preferred_element_type=jnp.float32,
  )

  @pl.when(pl.program_id(2) == nsteps - 1)
  def _():
    z_ref[...] = activation(acc_ref[...]).astype(z_ref.dtype)


@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'activation'])
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
    transpose_rhs: bool = False,
    activation: Callable[[jax.Array], jax.Array] = lambda x: x,
):
  if transpose_rhs:
    y = y.swapaxes(0, 1)
    y_block_spec = pl.BlockSpec((bn, bk), lambda i, j, k: (j, k))
  else:
    y_block_spec = pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      functools.partial(
          matmul_kernel,
          nsteps=k // bk,
          transpose_rhs=transpose_rhs,
          activation=activation,
      ),
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=0,
          in_specs=[
              pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
              y_block_spec,
          ],
          out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
          scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
          grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
                   mm_func, transpose_rhs: bool = False,
                   activation = lambda x: x):
  x = jnp.ones((m, k), dtype=dtype)
  if transpose_rhs:
    y = jnp.ones((n, k), dtype=dtype)
    @jax.jit
    def _wrapper(x, y):
      y = y.swapaxes(0, 1)
      return mm_func(x, y, transpose_rhs=True, activation=activation)
  else:
    y = jnp.ones((k, n), dtype=dtype)
    _wrapper = functools.partial(mm_func, activation=activation)
  time = benchmark(_wrapper)(x, y)
  print(f"----- {m} x {k} x {n} -----")
  print("Matmul time: ", time)
  mm_flops = matmul_flops(m, k, n) / time
  print("Matmul FLOP/s: ", mm_flops)
  print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
  print()


activation = jax.nn.relu
print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, activation=activation)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, activation=activation)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, activation=activation)

print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, activation=activation)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, activation=activation)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, activation=activation)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00030103540048003196
Matmul FLOP/s:  7133658182976.541
FLOP/s utilization: 3.6211%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.011807117109419778
Matmul FLOP/s:  11640348122095.826
FLOP/s utilization: 5.9088%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.09181861146935262
Matmul FLOP/s:  11974823079773.941
FLOP/s utilization: 6.0786%

================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00012622540001757442
Matmul FLOP/s:  17013086492108.6
FLOP/s utilization: 8.6361%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.000896632740041241
Matmul FLOP/s:  153283442968721.44
FLOP/s utilization: 77.8089%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006130605939542875
Matmul FLOP/s:  179347953304919.88
FLOP/s utilization: 91.0396%

额外的融合激活几乎不会影响我们的利用率!

结论#

在本指南中,我们介绍了如何使用 Pallas 在 TPU 上编写高效的矩阵乘法。我们讨论了分块矩阵乘法和流水线、如何分析 TPU 矩阵乘法的性能以及如何编写高效的 bf16 矩阵乘法。我们最后介绍了对矩阵乘法进行模板化以支持融合的转置和融合的激活函数。

留给读者的练习

  • 添加对输入融合的支持。有时我们希望将操作融合到矩阵乘法的输入中。尝试对矩阵乘法进行更多模板化以支持此功能。

  • 添加对 int8 矩阵乘法的支持。TPU v5 支持原生 int8 矩阵乘法,其 FLOP 是 bf16 的两倍。尝试添加对此功能的支持,看看可以实现什么利用率。

  • matmul 函数添加向后传递支持。您可以使用 jax.custom_vjp 来完成此操作。