矩阵乘法#

在本指南中,我们将使用 Pallas 编写一个矩阵乘法例程。我们还将介绍如何在 TPU 上考虑矩阵乘法性能,以及如何模板化矩阵乘法内核以融合操作。

#@title Imports
import functools
from typing import Callable

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax import random
import jax.numpy as jnp
import numpy as np

背景#

矩阵乘法是现代深度学习和语言建模的核心线性代数运算。我们希望使用像 TPU 和 GPU 这样的专用加速器尽可能地加快矩阵乘法的速度,它们都具有用于快速矩阵乘法的专用单元。

为了有效地利用 TPU 进行矩阵乘法,我们需要了解一些背景概念:分块矩阵乘法、分片和流水线。

分块矩阵乘法#

假设我们要实现 matmul(x, y),它通常将一个 (m, k) 数组与一个 (k, n) 数组相乘,但有一个转折。我们只允许使用基本运算 matmul_small,它乘以小矩阵(例如 m, k, n <= 256)。我们该如何实现呢?

矩阵乘法的一个很好的特性是,输出的每个块都可以表示为输入行块和列块的几个较小的矩阵乘法之和。形式上,如果我们有输入数组 \(x \in \mathbb{R}^{m \times k}\)\(y \in \mathbb{R}^{k \times n}\) 以及输出 \(z \in \mathbb{R}^{m \times n}\),我们将它们沿着大小为 \(b_m, b_k, b_n\) 的维度分解为块。

例如,\(x\) 将分解为

\[\begin{split} \begin{bmatrix} x_{0, 0} & \cdots & x_{0, i_k} \\ x_{1, 0} & \cdots & x_{1, i_k} \\ \vdots & \ddots & \vdots \\ x_{i_m, 0} & \cdots & x_{i_m, i_k} \\ \end{bmatrix} \end{split}\]

其中 \(x_{ik} \in \mathbb{R}^{b_m \times b_k}\)。(我们可以类似地分解 \(y\)\(z\)。)

对于特定的输出块 \(z_{ij}\),我们可以将其计算为

\[ z_{ij} = \sum_k x_{ik} y_{kj} \]

因此,每个输出块 \(z_{ij}\) 都是几个较小的分块矩阵乘法 \(x_{ik} y_{kj}\) 的总和。以下是如何在 NumPy 中实现此算法的方法

def matmul_small(x: np.ndarray, y: np.ndarray) -> np.ndarray:
  m, k, n = x.shape[0], x.shape[1], y.shape[0]
  assert m <= 256
  assert k <= 256
  assert n <= 256
  return np.matmul(x, y)

def block_matmul(
    x: np.ndarray,
    y: np.ndarray,
    *,
    bm: int = 256,
    bk: int = 256,
    bn: int = 256,
) -> np.ndarray:
  m, k = x.shape
  _, n = y.shape

  z = np.zeros((m, n), dtype=x.dtype)
  for m_i in range(m // bm):
    for n_i in range(n // bn):
      for k_i in range(k // bk):
        m_slice = slice(m_i * bm, (m_i + 1) * bm)
        k_slice = slice(k_i * bk, (k_i + 1) * bk)
        n_slice = slice(n_i * bn, (n_i + 1) * bn)
        x_block = x[m_slice, k_slice]
        y_block = y[k_slice, n_slice]
        z[m_slice, n_slice] += matmul_small(x_block, y_block)
  return z

我们的 block_matmul 函数现在应该可以处理大于 256 的输入(尽管我们假设我们的输入维度可以被 256 整除)。

m, k, n = 4096, 4096, 4096
x = np.random.uniform(size=(m, k)).astype(np.float32)
y = np.random.uniform(size=(k, n)).astype(np.float32)
np.testing.assert_allclose(x @ y, block_matmul(x, y), atol=1e-6, rtol=1e-6)

block_matmul 通过观察到大小为 (bm, bn) 的每个输出块可以通过累积几个 (bm, bk) x (bk, bn) 大小的矩阵乘法来计算,从而将矩阵乘法分解为多个较小的矩阵乘法。

TPU 和 GPU 执行矩阵乘法的方式就像这样!它们原生支持类似于 matmul_small 的小矩阵乘法,因此为了在执行较大的矩阵乘法时利用此硬件,我们将应用 block_matmul 分解。

分片和流水线#

之前的指南中,我们介绍了如何在 Pallas 中进行分片计算和流水线处理。为了确保我们的计算单元始终在工作,并且不会因内存传输而停顿,我们将内核的下一次迭代的内存传输与当前迭代的内存传输重叠。

在 Pallas 中,我们通过 BlockSpec 和一个 grid 来指定这一点。请注意,我们已经在分块矩阵乘法算法中有一个嵌套的 for 循环。我们可以通过 Pallas 中的 grid 来指定这一点。分块矩阵乘法中的切片也可以通过 BlockSpec 来指定。

您的第一个矩阵乘法内核#

将所有内容放在一起,下面是一个分块矩阵乘法内核的实现,该内核将内存传输与计算流水线化。我们创建一个 3 维网格,对应于 NumPy 代码中的 3 个嵌套循环。请注意,虽然 MXU 只能乘以小块,但 Pallas 会自动采用更大的块,并自动将它们分片到 MXU 上。

网格的最后一个维度对应于矩阵乘法的收缩维度,并且是一个归约维度,因此我们需要确保初始化累加器。

def matmul_kernel(x_ref, y_ref, z_ref):
  @pl.when(pl.program_id(2) == 0)
  def _():
    z_ref[...] = jnp.zeros_like(z_ref)

  z_ref[...] += x_ref[...] @ y_ref[...]

def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
):
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      matmul_kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      in_specs=[pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
                pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],
      out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
      grid=(m // bm, n // bn, k // bk),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
m, k, n = 4096, 4096, 4096
k1, k2 = random.split(random.key(0), 2)
x = random.normal(k1, (m, k), dtype=jnp.float32)
y = random.normal(k2, (k, n), dtype=jnp.float32)
np.testing.assert_array_equal(x @ y, matmul(x, y))

矩阵乘法性能#

让我们思考如何分析矩阵乘法的性能。当我们考虑矩阵乘法性能时,我们通常关心两件事:浮点运算 (FLOP) 的总数和内存带宽的使用量。从关于 TPU 和流水线的指南中,我们看到为了使用 TPU 上(以及通用 ML 加速器上)的高效计算单元,我们需要将输入从 HBM 复制到更接近计算单元的 VMEM 中。这种来回 HBM 的复制需要时间,高效的内核有望将其大部分时间用于实际计算,而不是等待这些传输。内存带宽衡量此数据传输的速率。

快速说明:在本指南中,我们将讨论浮点运算,但希望区分 FLOP 和 FLOP/s。当我们说“FLOP”时,我们的意思是“浮点运算”,如运算的数量。当我们说“FLOP/s”时,我们指的是“每秒浮点运算”,如执行浮点运算的速率

(m, k) x (k, n) 矩阵乘法中的 FLOP 数(近似)为 2 * m * k * n。(技术上,它是 n * m * (2k - 1),但对于足够大的 k,我们的近似值就足够了。)

矩阵乘法的最小内存带宽使用量(假设为 float32)是输入总大小(复制到 VMEM 中)加上输出大小(复制到 HBM 中)。因此,最小带宽使用量为 (m * k + k * n + m * n) * 4 字节/float32。如果我们多次读取输入,则内存使用量可能会更大,这种情况通常是这样。

一个观察结果是,矩阵乘法 FLOP 的数量是其输入的立方,而最小带宽使用量是其输入的平方。直观地说,这意味着 FLOP 的增长速度快于带宽使用量,这意味着我们的矩阵乘法越大,我们相对于复制的计算就越多。

def matmul_flops(m: int, k: int, n: int):
  return 2 * m * k * n

def matmul_membw(m: int, k: int, n: int, dtype: jnp.dtype):
  return (m * k + k * n + m * n) * np.dtype(dtype).itemsize

print(matmul_flops(1024, 1024, 1024))
print(matmul_membw(1024, 1024, 1024, jnp.float32))
2147483648
12582912

既然我们可以计算矩阵乘法的 FLOP 总数和(最小)内存带宽使用量,让我们看看真正的 TPU 可以处理什么。

此笔记本在 TPU v5e 芯片上运行,因此我们将使用 v5e 的数字(如果您正在运行此笔记本,则您的数字可能会有所不同)。TPU v5e 具有 197 TFLOP/s 的 bf16/f32 计算能力和 819 GB/s 的内存带宽。通过查看这些数字的比率(称为算术强度),我们可以获得一个下限,即在我们在 TPU v5e 上受到 IO 限制之前,此“FLOP/内存带宽使用量”的比率可以有多低(大约为 240 FLOP/字节)。

v5e_flops = 197e12
v5e_membw = 819e9
v5e_op_intensity = v5e_flops / v5e_membw  # ~240.5

粗略地说,这些数字告诉我们矩阵乘法的 FLOP 应花费 2 * m * k * n / (197 TFLOP/s) 秒,而与 VMEM 的来回复制应花费 (m*k + k*n + m*n) * 4 字节 / 819GB/s 秒。

def matmul_flops_intensity(m: int, k: int, n: int, dtype: jnp.dtype):
  flops = matmul_flops(m, k, n)
  membw = matmul_membw(m, k, n, dtype)
  return flops / membw

此基本计算告诉我们大致如何有效地使用我们的 MXU。如果我们的矩阵乘法运算强度低于我们芯片的能力,那么我们的计算将受到内存限制,即我们的计算单元在等待值传输时将处于空闲状态。如果矩阵乘法强度高于芯片的能力,那么我们将受到计算限制

由于矩阵乘法 FLOP 是其输入大小的立方,而内存带宽使用量是平方,我们预计随着我们变得越来越大,我们将受到计算限制,但是这个交叉点非常重要!假设我们正在执行 (1024, 1024) x (1024, 1024) float32 矩阵乘法。

print(f"{matmul_flops_intensity(1024, 1024, 1024, jnp.float32)} flops/byte")
170.66666666666666 flops/byte

我们的矩阵乘法 FLOP 强度低于我们芯片的能力。这不好!我们很可能受到这种矩阵乘法的内存限制。但是,如果我们的输入和输出更大呢?在某些时候,当我们的矩阵乘法变得足够大时,我们将从内存限制过渡到计算限制。例如,如果我们有一个矩阵乘法,其中 m = k = n,则当 2m**3 / 12m**2 > 240 或当 m = k = n > 1440 时,我们将在 (TPU v5e) 上过渡。

bfloat16 矩阵乘法#

为了使矩阵乘法在 TPU 上更容易受到计算限制,我们还可以为输入和输出使用较小的 dtype。我们之前的示例使用了 float32 输入和输出,但 TPU v5e 也支持 bfloat16 数据类型(一种 16 位浮点格式,也称为 bf16)用于矩阵乘法。在 TPU v5e 上,我们将拥有相同的 FLOP/s,但会将内存带宽使用量减半。这使得较小矩阵更容易受到计算限制。让我们看看一个 1024 x 1024 x 1024 的 bf16 矩阵乘法的强度是多少

print(f"{matmul_flops_intensity(1024, 1024, 1024, jnp.bfloat16)} flops/byte")
341.3333333333333 flops/byte

我们现在有一个受计算限制的矩阵乘法!

让我们将 bf16 支持添加到我们的矩阵乘法内核。

原生的 MXU bf16 矩阵乘法例程接受两个 bf16 输入矩阵并在 f32 中累积。我们将通过将 preferred_element_type=jnp.float32 传递到 jnp.matmul 来触发此例程。我们还需要一个 f32 中的累加器 Ref。然后,我们将在写回 HBM 之前将输出向下转换为 bf16。这样我们就不会丢失任何精度,不会进行任何额外的转换,并且仍然保留 bf16 的内存带宽节省。

请注意,目前分配临时空间的唯一方法是通过 pltpu.PrefetchScalarGridSpec。现在不用担心它具体做什么——现在你需要知道的是它允许你在 VMEM 中分配临时空间。

def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps):
  @pl.when(pl.program_id(2) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  acc_ref[...] += jnp.dot(
      x_ref[...], y_ref[...], preferred_element_type=jnp.float32
  )

  @pl.when(pl.program_id(2) == nsteps - 1)
  def _():
    z_ref[...] = acc_ref[...].astype(z_ref.dtype)


@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn'])
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
):
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      functools.partial(matmul_kernel, nsteps=k // bk),
      grid_spec=pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[
            pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
            pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)),
        ],
        out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
        scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
        grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
m, k, n = 4096, 4096, 4096
k1, k2 = random.split(random.key(0), 2)
x = random.normal(k1, (m, k), dtype=jnp.bfloat16)
y = random.normal(k2, (k, n), dtype=jnp.bfloat16)
np.testing.assert_array_equal(x @ y, matmul(x, y))

流水线内核的性能#

我们上面关于 FLOP 与内存使用的分析适用于粗略的尺度,即当我们查看整个矩阵乘法的大小时。但是,请记住,在实践中,我们正在对分块矩阵乘法的执行进行流水线处理,这意味着我们有一个循环,其中我们正在使用较小的块进行矩阵乘法。

这意味着我们实际上关心的是内核的每个单独实例的 FLOP 与内存带宽使用情况,而不是全局 FLOP 与内存带宽使用情况。因此,块大小 bmbkbn 对于性能至关重要。即使我们拥有世界上最大的矩阵,如果我们选择非常小的 bmbkbn,我们将受到内存限制,因为每次调用内核时,我们拥有的 FLOP 太少,无法隐藏后台发生的内存传输。

因此,直觉应该是:要受计算限制,请使块尽可能大!有两个主要约束

  1. VMEM 使用:我们的块越大,我们使用的 VMEM 就越多。当块足够大时,我们将耗尽。

  2. 流水线气泡:我们的块相对于矩阵大小越大,我们在流水线中的循环迭代次数就越少。这将使流水线开始和结束时的气泡大小相对于整个流水线更大,并且这种开销可能不是微不足道的。

在 Pallas 中获得良好的矩阵乘法性能归结为选择良好的块大小以平衡此优化问题。在实践中,我们通常会扫描大量的候选块大小,分析内核,并选择最佳的块大小。

现在,让我们做一些非常简单的计时实验。我们将使用 timeit 来测量运行每个内核所花费的时间。请注意,这只是内核实际运行时间的上限,因为我们正在使用 timeit 来测量 Python 分派和其他开销。我们将以这种方式计算我们获得的 FLOP/s 量,并计算我们与芯片提供的相比所获得的利用率百分比,我们将使用一些合理的块大小来验证我们的直觉。

import timeit

def benchmark(f, ntrials: int = 100):
  def run(*args, **kwargs):
    # Compile function first
    jax.block_until_ready(f(*args, **kwargs))
    # Time function
    result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
                           number=ntrials)
    time = result / ntrials
    # print(f"Time: {time}")
    return time
  return run

def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
                   mm_func):
  x = jnp.ones((m, k), dtype=dtype)
  y = jnp.ones((k, n), dtype=dtype)
  time = benchmark(mm_func)(x, y)
  print(f"----- {m} x {k} x {n} -----")
  print("Matmul time: ", time)
  mm_flops = matmul_flops(m, k, n) / time
  print("Matmul FLOP/s: ", mm_flops)
  print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
  print()

print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)

print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00029766598949208854
Matmul FLOP/s:  7214407167121.377
FLOP/s utilization: 3.6621%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.011771515250438824
Matmul FLOP/s:  11675553278230.387
FLOP/s utilization: 5.9267%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.09183577066054567
Matmul FLOP/s:  11972585626140.668
FLOP/s utilization: 6.0775%

================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00012708659982308746
Matmul FLOP/s:  16897797651282.135
FLOP/s utilization: 8.5776%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.00088908776990138
Matmul FLOP/s:  154584235803001.88
FLOP/s utilization: 78.4692%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006099433819763363
Matmul FLOP/s:  180264539343531.62
FLOP/s utilization: 91.5048%

更大的块大小有很大帮助!我们在更大的矩阵乘法中获得了相当好的利用率(80-90%),但是最小的矩阵乘法似乎很难获得良好的性能。

让我们将其与 XLA 的矩阵乘法进行比较。我们不期望 Pallas 比 XLA 做得更好,因为 XLA 在生成矩阵乘法方面非常出色,但希望我们接近。通过更仔细的块大小调整(留作未来工作),我们也可以达到 XLA 的性能。

print("================ XLA matmul ===================")
mm = jnp.matmul
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
================ XLA matmul ===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00011943008983507753
Matmul FLOP/s:  17981093801113.996
FLOP/s utilization: 9.1275%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.0008272899803705514
Matmul FLOP/s:  166131533963991.34
FLOP/s utilization: 84.3307%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006047147869830951
Matmul FLOP/s:  181823175395037.44
FLOP/s utilization: 92.2960%

Pallas 通过一些非常基本的调整,非常接近 XLA 的性能数字!通过尝试更多的块大小,我们应该可以完全消除差距。

模板化矩阵乘法#

现在我们有了一个基本的矩阵乘法内核,我们现在可以尝试将操作融合到其中。

融合的右侧转置#

通常要做的第一件事是融合转置。那是什么意思?假设我们想要计算 x @ y.T 而不是 x @ y。最简单的方法是先计算 y.T,然后将其传递到我们高效的矩阵乘法内核中。但是,操作 y.T 本身不是免费的——它涉及复制 O(n^2) 数据。理想情况下,我们可以执行矩阵乘法时计算转置,只需一个内核,即将其与矩阵乘法“融合”。

加速器通常支持融合 RHS 转置的原生矩阵乘法例程。例如,TPU v5e,MXU 允许我们对小数组执行 x @ y.T。我们可以使用 jax.lax.dot_general 调用此例程,这将比单独执行转置然后执行矩阵乘法更有效。

def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps, transpose_rhs):
  @pl.when(pl.program_id(2) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  # dot_general expects a data structure (contraction_dims, batch_dims),
  # where contraction_dims are the set of dimensions for LHS and RHS that will
  # be contracted (reduced) in the matmul; batch_dims, on the other hand, are
  # looped over. The remaining dimensions will be the input and output dimension
  # of the matmul.
  if transpose_rhs:
    dims = ((1,), (1,)), ((), ())
  else:
    dims = ((1,), (0,)), ((), ())

  acc_ref[...] += jax.lax.dot_general(
      x_ref[...], y_ref[...], dims, preferred_element_type=jnp.float32,
  )

  @pl.when(pl.program_id(2) == nsteps - 1)
  def _():
    z_ref[...] = acc_ref[...].astype(z_ref.dtype)


@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'transpose_rhs'])
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
    transpose_rhs: bool = False,
):
  if transpose_rhs:
    y = y.swapaxes(0, 1)
    y_block_spec = pl.BlockSpec((bn, bk), lambda i, j, k: (j, k))
  else:
    y_block_spec = pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      functools.partial(matmul_kernel, nsteps=k // bk, transpose_rhs=transpose_rhs),
      grid_spec=pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[
            pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
            y_block_spec,
        ],
        out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
        scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
        grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)

我们在 matmul 函数中执行转置(y = y.swapaxes(0, 1))。这是因为在 JIT 编译的 JAX 计算中,维度排序纯粹是逻辑的,而不是物理的,因此重新排列维度并不意味着物理布局的差异。但是,当我们传递一个数组到 pallas_call 时,我们确实强制执行了主到次的维度排序约束。通过在 matmul 函数内部转置 y,我们请求 y 采用转置布局 (n, k) 而不是通常的 (k, n)。用户仍然会以(逻辑)(k, n) 维度传递数组。

注意:要对转置进行基准测试,我们实际上希望在将 y 传递到内核时,y 采用物理转置布局,这样我们才不会测量重新布局时间。在包装函数中,我们将在将其传递到 matmul 之前将其(逻辑上)转置回 (k, n),因为 matmul 需要逻辑 (k, n) 维度排序。

def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
                   mm_func, transpose_rhs: bool = False):
  x = jnp.ones((m, k), dtype=dtype)
  if transpose_rhs:
    y = jnp.ones((n, k), dtype=dtype)
    @jax.jit
    def _wrapper(x, y):
      y = y.swapaxes(0, 1)
      return mm_func(x, y, transpose_rhs=True)
  else:
    y = jnp.ones((k, n), dtype=dtype)
    _wrapper = mm_func
  time = benchmark(_wrapper)(x, y)
  print(f"----- {m} x {k} x {n} -----")
  print("Matmul time: ", time)
  mm_flops = matmul_flops(m, k, n) / time
  print("Matmul FLOP/s: ", mm_flops)
  print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
  print()

print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, transpose_rhs=True)

print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, transpose_rhs=True)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.0003029372810851783
Matmul FLOP/s:  7088872126624.065
FLOP/s utilization: 3.5984%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.012017967159627005
Matmul FLOP/s:  11436123235026.848
FLOP/s utilization: 5.8051%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.09500920018996112
Matmul FLOP/s:  11572685861765.383
FLOP/s utilization: 5.8745%

================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00012131539988331496
Matmul FLOP/s:  17701657415839.363
FLOP/s utilization: 8.9856%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.0008790623804088682
Matmul FLOP/s:  156347213275211.03
FLOP/s utilization: 79.3641%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006107717020204291
Matmul FLOP/s:  180020067095253.78
FLOP/s utilization: 91.3807%

看看尽管进行了额外的转置,我们如何获得相同的利用率!

融合的激活函数#

融合激活函数也很常见。这可以确保我们不会在高效的、受计算限制的矩阵乘法内核之后使用缓慢的、受内存限制的激活内核。

def matmul_kernel(
    x_ref, y_ref, z_ref, acc_ref, *, nsteps, transpose_rhs, activation
):
  @pl.when(pl.program_id(2) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  if transpose_rhs:
    dims = ((1,), (1,)), ((), ())
  else:
    dims = ((1,), (0,)), ((), ())

  acc_ref[...] += jax.lax.dot_general(
      x_ref[...],
      y_ref[...],
      dims,
      preferred_element_type=jnp.float32,
  )

  @pl.when(pl.program_id(2) == nsteps - 1)
  def _():
    z_ref[...] = activation(acc_ref[...]).astype(z_ref.dtype)


@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'activation'])
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
    transpose_rhs: bool = False,
    activation: Callable[[jax.Array], jax.Array] = lambda x: x,
):
  if transpose_rhs:
    y = y.swapaxes(0, 1)
    y_block_spec = pl.BlockSpec((bn, bk), lambda i, j, k: (j, k))
  else:
    y_block_spec = pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      functools.partial(
          matmul_kernel,
          nsteps=k // bk,
          transpose_rhs=transpose_rhs,
          activation=activation,
      ),
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=0,
          in_specs=[
              pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
              y_block_spec,
          ],
          out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
          scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
          grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
                   mm_func, transpose_rhs: bool = False,
                   activation = lambda x: x):
  x = jnp.ones((m, k), dtype=dtype)
  if transpose_rhs:
    y = jnp.ones((n, k), dtype=dtype)
    @jax.jit
    def _wrapper(x, y):
      y = y.swapaxes(0, 1)
      return mm_func(x, y, transpose_rhs=True, activation=activation)
  else:
    y = jnp.ones((k, n), dtype=dtype)
    _wrapper = functools.partial(mm_func, activation=activation)
  time = benchmark(_wrapper)(x, y)
  print(f"----- {m} x {k} x {n} -----")
  print("Matmul time: ", time)
  mm_flops = matmul_flops(m, k, n) / time
  print("Matmul FLOP/s: ", mm_flops)
  print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
  print()


activation = jax.nn.relu
print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, activation=activation)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, activation=activation)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, activation=activation)

print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, activation=activation)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, activation=activation)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, activation=activation)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00030103540048003196
Matmul FLOP/s:  7133658182976.541
FLOP/s utilization: 3.6211%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.011807117109419778
Matmul FLOP/s:  11640348122095.826
FLOP/s utilization: 5.9088%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.09181861146935262
Matmul FLOP/s:  11974823079773.941
FLOP/s utilization: 6.0786%

================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00012622540001757442
Matmul FLOP/s:  17013086492108.6
FLOP/s utilization: 8.6361%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.000896632740041241
Matmul FLOP/s:  153283442968721.44
FLOP/s utilization: 77.8089%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006130605939542875
Matmul FLOP/s:  179347953304919.88
FLOP/s utilization: 91.0396%

额外的融合激活几乎不会影响我们的利用率!

结论#

在本指南中,我们介绍了如何使用 Pallas 在 TPU 上编写高效的矩阵乘法。我们讨论了分块矩阵乘法和流水线,如何分析 TPU 矩阵乘法的性能,以及如何编写高效的 bf16 矩阵乘法。最后,我们对矩阵乘法进行模板化,以支持融合的转置和融合的激活函数。

留给读者的练习

  • 添加对输入融合的支持。有时我们想将操作融合到矩阵乘法的输入中。尝试更多地模板化矩阵乘法以支持这一点。

  • 添加对 int8 矩阵乘法的支持。TPU v5 支持原生的 int8 矩阵乘法,其 FLOP 是 bf16 的两倍。尝试添加对它的支持,看看可以实现什么利用率。

  • 添加对 matmul 函数的反向传递支持。您可以使用 jax.custom_vjp 来完成此操作。