矩阵乘法#
在本指南中,我们将使用 Pallas 编写一个矩阵乘法例程。我们还将讨论如何思考 TPU 上的矩阵乘法性能以及如何模板化矩阵乘法内核以融合操作。
#@title Imports
import functools
from typing import Callable
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax import random
import jax.numpy as jnp
import numpy as np
背景#
矩阵乘法是现代深度学习和语言建模核心的一种基本线性代数运算。我们希望使用 TPU 和 GPU 等专用加速器使矩阵乘法尽可能快,两者都具有专门的单元来快速执行矩阵乘法。
为了有效地利用 TPU 进行矩阵乘法,我们需要了解一些背景知识:分块矩阵乘法、平铺和流水线。
分块矩阵乘法#
假设我们想要实现一个 matmul(x, y)
函数,它可以通用地将一个 (m, k)
矩阵乘以一个 (k, n)
矩阵,但需要一个小小的变化。我们只能使用一个名为 matmul_small
的基本函数,它只能够乘以较小的矩阵(例如,m, k, n <= 256
)。我们该如何实现呢?
矩阵乘法的一个很好的性质是,输出矩阵的每个块都可以表示为输入矩阵的行块和列块进行若干次较小矩阵乘法的累加结果。更正式地说,如果我们有两个输入矩阵 \(x \in \mathbb{R}^{m \times k}\) 和 \(y \in \mathbb{R}^{k \times n}\),输出矩阵为 \(z \in \mathbb{R}^{m \times n}\),我们将它们沿着各个维度分解成大小为 \(b_m, b_k, b_n\) 的块。
例如,\(x\) 可以分解成如下形式:
其中 \(x_{ik} \in \mathbb{R}^{b_m \times b_k}\)。(我们也可以用类似的方法分解 \(y\) 和 \(z\))。
对于特定的输出块 \(z_{ij}\),我们可以将其计算为:
因此,每个输出块 \(z_{ij}\) 都是若干个较小的块矩阵乘法 \(x_{ik} y_{kj}\) 的累加结果。以下是用 NumPy 实现该算法的方法:
def matmul_small(x: np.ndarray, y: np.ndarray) -> np.ndarray:
m, k, n = x.shape[0], x.shape[1], y.shape[0]
assert m <= 256
assert k <= 256
assert n <= 256
return np.matmul(x, y)
def block_matmul(
x: np.ndarray,
y: np.ndarray,
*,
bm: int = 256,
bk: int = 256,
bn: int = 256,
) -> np.ndarray:
m, k = x.shape
_, n = y.shape
z = np.zeros((m, n), dtype=x.dtype)
for m_i in range(m // bm):
for n_i in range(n // bn):
for k_i in range(k // bk):
m_slice = slice(m_i * bm, (m_i + 1) * bm)
k_slice = slice(k_i * bk, (k_i + 1) * bk)
n_slice = slice(n_i * bn, (n_i + 1) * bn)
x_block = x[m_slice, k_slice]
y_block = y[k_slice, n_slice]
z[m_slice, n_slice] += matmul_small(x_block, y_block)
return z
现在我们的 block_matmul
函数应该能够处理大于 256 的输入(尽管我们假设输入维度可以被 256 整除)。
m, k, n = 4096, 4096, 4096
x = np.random.uniform(size=(m, k)).astype(np.float32)
y = np.random.uniform(size=(k, n)).astype(np.float32)
np.testing.assert_allclose(x @ y, block_matmul(x, y), atol=1e-6, rtol=1e-6)
block_matmul
函数将矩阵乘法分解成许多较小的矩阵乘法,方法是观察到,每个大小为 (bm, bn)
的输出块都可以通过累加若干个 (bm, bk) x (bk, bn)
大小的矩阵乘法结果来计算。
TPU 和 GPU 也以类似的方式进行矩阵乘法!它们原生支持类似于 matmul_small
的小矩阵乘法,因此,为了在进行更大的矩阵乘法时利用这种硬件,我们将应用 block_matmul
分解方法。
分块和流水线#
在 之前的指南 中,我们介绍了如何在 Pallas 中对计算进行分块和流水线操作。为了确保我们的计算单元始终处于工作状态,并且不会因为内存传输而停滞,我们将下一个内核迭代的内存传输与当前迭代的内存传输重叠起来。
在 Pallas 中,我们可以通过 BlockSpec
和 grid
来指定这一点。请注意,我们在块矩阵乘法算法中已经有一个嵌套的 for 循环。我们可以通过 grid
在 Pallas 中指定这一点。块矩阵乘法中的切片也可以通过 BlockSpec
来指定。
你的第一个矩阵乘法内核#
将所有这些内容整合在一起,下面是一个块矩阵乘法内核的实现,它对内存传输和计算进行流水线操作。我们创建了一个 3 维网格,对应于 NumPy 代码中的 3 层嵌套循环。请注意,虽然 MXU 只能够乘以小块矩阵,但 Pallas 会自动将更大的块矩阵拆分并在 MXU 上进行分块操作。
网格的最后一维对应于矩阵乘法的压缩维度,它是一个归约维度,因此,我们需要确保将累加器初始化。
def matmul_kernel(x_ref, y_ref, z_ref):
@pl.when(pl.program_id(2) == 0)
def _():
z_ref[...] = jnp.zeros_like(z_ref)
z_ref[...] += x_ref[...] @ y_ref[...]
def matmul(
x: jax.Array,
y: jax.Array,
*,
bm: int = 128,
bk: int = 128,
bn: int = 128,
):
m, k = x.shape
_, n = y.shape
return pl.pallas_call(
matmul_kernel,
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
in_specs=[pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],
out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
grid=(m // bm, n // bn, k // bk),
compiler_params=pltpu.TPUCompilerParams(
dimension_semantics=("parallel", "parallel", "arbitrary")),
)(x, y)
m, k, n = 4096, 4096, 4096
k1, k2 = random.split(random.key(0), 2)
x = random.normal(k1, (m, k), dtype=jnp.float32)
y = random.normal(k2, (k, n), dtype=jnp.float32)
np.testing.assert_array_equal(x @ y, matmul(x, y))
矩阵乘法的性能#
让我们思考一下如何分析矩阵乘法的性能。当我们考虑矩阵乘法的性能时,通常会关注两方面:总的浮点运算次数 (FLOPs) 和内存带宽使用量。从 关于 TPU 和流水线的指南 中,我们可以看到,为了有效地使用 TPU 上的计算单元(以及一般情况下 ML 加速器上的计算单元),我们需要将输入从 HBM 复制到 VMEM,即更靠近计算单元的位置。这种从 HBM 到 VMEM 的复制操作需要时间,一个高效的内核应该将大部分时间花在实际计算上,而不是等待这些传输。内存带宽衡量的是这种数据传输的速率。
快速说明:在本指南中,我们将讨论浮点运算,但需要区分 FLOPs 和 FLOP/s。当我们说“FLOPs”时,指的是“浮点运算次数”,即运算次数。当我们说“FLOP/s”时,指的是“每秒的浮点运算次数”,即执行浮点运算的速率。
一个 (m, k) x (k, n)
矩阵乘法的 FLOPs 大致为 2 * m * k * n
。(严格来说,它是 n * m * (2k - 1)
,但对于足够大的 k
,我们的近似值就足够了)。
矩阵乘法的最小内存带宽使用量(假设为 float32)是输入总大小(复制到 VMEM)加上输出大小(复制到 HBM)。因此,最小带宽使用量为 (m * k + k * n + m * n) * 4 bytes/float32
。如果我们多次重新读取输入,内存使用量可能会更大,这种情况很常见。
一个观察结果是,矩阵乘法的 FLOPs 与其输入呈立方关系,而最小带宽使用量与其输入呈二次关系。直观地说,这意味着 FLOPs 的增长速度快于带宽使用量,也就是说,矩阵乘法越大,其计算量相对于复制操作的比例就越大。
def matmul_flops(m: int, k: int, n: int):
return 2 * m * k * n
def matmul_membw(m: int, k: int, n: int, dtype: jnp.dtype):
return (m * k + k * n + m * n) * np.dtype(dtype).itemsize
print(matmul_flops(1024, 1024, 1024))
print(matmul_membw(1024, 1024, 1024, jnp.float32))
2147483648
12582912
现在,我们可以计算矩阵乘法的总 FLOPs 和(最小)内存带宽使用量,让我们看看真实的 TPU 可以处理多少。
本笔记本是在 TPU v5e 芯片上运行的,因此我们将使用 v5e 的数据(如果您正在运行本笔记本,您的数据可能会有所不同)。TPU v5e 具有 197 TFLOP/s 的 bf16/f32 计算能力和 819 GB/s 的内存带宽。通过查看这两个数字的比率(称为算术强度),我们可以获得一个边界值,该边界值表明,当这个“FLOPs / 内存带宽使用量”比率低于该值时,我们就将变得 I/O 受限(在 TPU v5e 上大约为 240 FLOPs/字节)。
v5e_flops = 197e12
v5e_membw = 819e9
v5e_op_intensity = v5e_flops / v5e_membw # ~240.5
粗略地说,这些数字告诉我们,矩阵乘法的 FLOPs 大致需要 2 * m * k * n / (197 TFLOP/s)
秒,而从 VMEM 到 HBM 的复制操作需要 (m*k + k*n + m*n) * 4 bytes / 819GB/s
秒。
def matmul_flops_intensity(m: int, k: int, n: int, dtype: jnp.dtype):
flops = matmul_flops(m, k, n)
membw = matmul_membw(m, k, n, dtype)
return flops / membw
这个基本计算告诉我们,我们能够以多高的效率使用 MXU。如果矩阵乘法的操作强度低于我们芯片的性能,那么我们的计算将是 *内存受限* 的,也就是说,我们的计算单元将在等待数据传输时处于空闲状态。如果矩阵乘法的操作强度高于芯片的性能,那么我们将是 *计算受限* 的。
因为矩阵乘法的 FLOPs 与其输入大小呈立方关系,而内存带宽使用量与其输入大小呈二次关系,所以我们预计,随着矩阵乘法变得越来越大,我们将变得计算受限,但这个交叉点非常重要!假设我们要进行一个 (1024, 1024) x (1024, 1024)
的 float32 矩阵乘法。
print(f"{matmul_flops_intensity(1024, 1024, 1024, jnp.float32)} flops/byte")
170.66666666666666 flops/byte
我们的矩阵乘法 FLOPs 强度低于我们芯片的性能。这不好!在这种类型的矩阵乘法中,我们很可能会受到内存的限制。但是,如果我们的输入和输出更大呢?当我们的矩阵乘法足够大时,我们将从内存受限状态转变为计算受限状态。例如,如果我们有一个矩阵乘法,其中 m = k = n
,那么我们在 TPU v5e 上的交叉点将是 2m**3 / 12m**2 > 240
,即 m = k = n > 1440
。
bfloat16
矩阵乘法#
为了让矩阵乘法更容易在 TPU 上变得计算受限,我们可以使用更小的数据类型来表示输入和输出。我们之前的示例使用了 float32
类型的输入和输出,但 TPU v5e 也支持 bfloat16
数据类型(一种 16 位浮点格式,也称为 bf16
)来进行矩阵乘法。在 TPU v5e 上,我们将拥有相同的 FLOP/s,但 *内存带宽使用量将减半*。这使得较小的矩阵更容易变得计算受限。让我们看看 1024 x 1024 x 1024 的 bf16
矩阵乘法的强度。
print(f"{matmul_flops_intensity(1024, 1024, 1024, jnp.bfloat16)} flops/byte")
341.3333333333333 flops/byte
现在,我们得到了一个计算受限的矩阵乘法!
让我们在矩阵乘法内核中添加 bf16
支持。
原生 MXU bf16
矩阵乘法例程接受两个 bf16
矩阵作为输入,并将它们累加到 f32
中。我们将通过将 preferred_element_type=jnp.float32
传递到 jnp.matmul
来触发该例程。我们还需要一个 f32
类型的累加器 Ref
。然后,我们将输出向下转换为 bf16
,然后再将其写回到 HBM。这样,我们就不会损失任何精度,也不会进行任何额外的转换,并且仍然可以保留 bf16
内存带宽的节省。
请注意,目前分配 scratch space 的唯一方法是通过
pltpu.PrefetchScalarGridSpec
。现在不用担心它的具体作用,你只需要知道它允许你在 VMEM 中分配 scratch space。
def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps):
@pl.when(pl.program_id(2) == 0)
def _():
acc_ref[...] = jnp.zeros_like(acc_ref)
acc_ref[...] += jnp.dot(
x_ref[...], y_ref[...], preferred_element_type=jnp.float32
)
@pl.when(pl.program_id(2) == nsteps - 1)
def _():
z_ref[...] = acc_ref[...].astype(z_ref.dtype)
@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn'])
def matmul(
x: jax.Array,
y: jax.Array,
*,
bm: int = 128,
bk: int = 128,
bn: int = 128,
):
m, k = x.shape
_, n = y.shape
return pl.pallas_call(
functools.partial(matmul_kernel, nsteps=k // bk),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)),
],
out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
grid=(m // bm, n // bn, k // bk),
),
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
compiler_params=pltpu.TPUCompilerParams(
dimension_semantics=("parallel", "parallel", "arbitrary")),
)(x, y)
m, k, n = 4096, 4096, 4096
k1, k2 = random.split(random.key(0), 2)
x = random.normal(k1, (m, k), dtype=jnp.bfloat16)
y = random.normal(k2, (k, n), dtype=jnp.bfloat16)
np.testing.assert_array_equal(x @ y, matmul(x, y))
流水线内核的性能#
我们上面关于 FLOPs 与内存使用量的分析适用于一个粗略的尺度,即当我们查看整个矩阵乘法的规模时。但是,请记住,在实践中,我们正在对阻塞矩阵乘法进行流水线执行,这意味着我们有一个循环,在这个循环中我们对更小的块进行矩阵乘法。
这意味着我们实际上关心的是每个内核实例的 FLOPs 与内存带宽使用量,而不是全局 FLOPs 与内存带宽使用量。因此,块大小 bm
、bk
和 bn
对性能至关重要。即使我们拥有世界上最大的矩阵,如果我们选择非常小的 bm
、bk
和 bn
,我们将会受到内存的限制,因为每次调用内核时,我们的 FLOPs 太少,无法隐藏后台发生的内存传输。
因此,直觉应该是:要成为计算受限的,尽量使块大小尽可能大!有两个主要限制
VMEM 使用量:块越大,我们使用的 VMEM 就越多。如果块足够大,我们将耗尽 VMEM。
流水线气泡:块的大小相对于矩阵大小越大,流水线中的循环迭代次数就越少。这将使流水线开头和结尾的气泡大小相对于整个流水线更大,而这种开销可能是非微不足道的。
在 Pallas 中获得良好的矩阵乘法性能归结为选择良好的块大小来平衡此优化问题。在实践中,我们通常会遍历一组大量的候选块大小,对内核进行性能分析,并选择最佳的块大小。
现在,让我们进行一些非常简单的计时实验。我们将使用 timeit
来衡量运行每个内核所需的时间。请注意,这是内核实际运行时间的上限,因为我们使用 timeit
来衡量 Python 分派和其他开销。我们将计算以这种方式获得的 FLOP/s 数量,并计算与芯片提供的性能相比获得的利用率百分比,并将使用一些合理的块大小来验证我们的直觉。
import timeit
def benchmark(f, ntrials: int = 100):
def run(*args, **kwargs):
# Compile function first
jax.block_until_ready(f(*args, **kwargs))
# Time function
result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
number=ntrials)
time = result / ntrials
# print(f"Time: {time}")
return time
return run
def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
mm_func):
x = jnp.ones((m, k), dtype=dtype)
y = jnp.ones((k, n), dtype=dtype)
time = benchmark(mm_func)(x, y)
print(f"----- {m} x {k} x {n} -----")
print("Matmul time: ", time)
mm_flops = matmul_flops(m, k, n) / time
print("Matmul FLOP/s: ", mm_flops)
print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
print()
print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time: 0.00029766598949208854
Matmul FLOP/s: 7214407167121.377
FLOP/s utilization: 3.6621%
----- 4096 x 4096 x 4096 -----
Matmul time: 0.011771515250438824
Matmul FLOP/s: 11675553278230.387
FLOP/s utilization: 5.9267%
----- 8192 x 8192 x 8192 -----
Matmul time: 0.09183577066054567
Matmul FLOP/s: 11972585626140.668
FLOP/s utilization: 6.0775%
================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time: 0.00012708659982308746
Matmul FLOP/s: 16897797651282.135
FLOP/s utilization: 8.5776%
----- 4096 x 4096 x 4096 -----
Matmul time: 0.00088908776990138
Matmul FLOP/s: 154584235803001.88
FLOP/s utilization: 78.4692%
----- 8192 x 8192 x 8192 -----
Matmul time: 0.006099433819763363
Matmul FLOP/s: 180264539343531.62
FLOP/s utilization: 91.5048%
更大的块大小帮助很大!我们在较大的矩阵乘法中获得了相当好的利用率(80-90%),但最小的矩阵乘法似乎很难获得良好的性能。
让我们将其与 XLA 的矩阵乘法进行比较。我们不希望 Pallas 的性能超过 XLA,因为 XLA 在生成矩阵乘法方面非常出色,但希望我们能够接近。通过更仔细地调整块大小(留作今后的工作),我们也可以达到 XLA 的性能。
print("================ XLA matmul ===================")
mm = jnp.matmul
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
================ XLA matmul ===================
----- 1024 x 1024 x 1024 -----
Matmul time: 0.00011943008983507753
Matmul FLOP/s: 17981093801113.996
FLOP/s utilization: 9.1275%
----- 4096 x 4096 x 4096 -----
Matmul time: 0.0008272899803705514
Matmul FLOP/s: 166131533963991.34
FLOP/s utilization: 84.3307%
----- 8192 x 8192 x 8192 -----
Matmul time: 0.006047147869830951
Matmul FLOP/s: 181823175395037.44
FLOP/s utilization: 92.2960%
Pallas 在进行了一些非常基本的调整后,已经非常接近 XLA 的性能数字!通过尝试更多块大小,我们应该可以完全缩小差距。
矩阵乘法的模板化#
现在我们有了基本的矩阵乘法内核,我们可以尝试将操作融合到其中。
融合的右手侧转置#
要做的第一件事通常是融合转置。我们指的是什么呢?假设我们想要计算 x @ y.T
而不是 x @ y
。我们可以先计算 y.T
,然后将其传递给我们的高效矩阵乘法内核。但是,操作 y.T
本身并不免费 - 它涉及复制 O(n^2)
数据。理想情况下,我们可以在同一个内核中进行矩阵乘法时计算转置,即将其与矩阵乘法“融合”。
加速器通常支持融合 RHS 转置的本地矩阵乘法例程。例如,TPU v5e,MXU 允许我们对小数组进行 x @ y.T
。我们可以使用 jax.lax.dot_general
调用此例程,这将比单独进行转置和矩阵乘法更高效。
def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps, transpose_rhs):
@pl.when(pl.program_id(2) == 0)
def _():
acc_ref[...] = jnp.zeros_like(acc_ref)
# dot_general expects a data structure (contraction_dims, batch_dims),
# where contraction_dims are the set of dimensions for LHS and RHS that will
# be contracted (reduced) in the matmul; batch_dims, on the other hand, are
# looped over. The remaining dimensions will be the input and output dimension
# of the matmul.
if transpose_rhs:
dims = ((1,), (1,)), ((), ())
else:
dims = ((1,), (0,)), ((), ())
acc_ref[...] += jax.lax.dot_general(
x_ref[...], y_ref[...], dims, preferred_element_type=jnp.float32,
)
@pl.when(pl.program_id(2) == nsteps - 1)
def _():
z_ref[...] = acc_ref[...].astype(z_ref.dtype)
@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'transpose_rhs'])
def matmul(
x: jax.Array,
y: jax.Array,
*,
bm: int = 128,
bk: int = 128,
bn: int = 128,
transpose_rhs: bool = False,
):
if transpose_rhs:
y = y.swapaxes(0, 1)
y_block_spec = pl.BlockSpec((bn, bk), lambda i, j, k: (j, k))
else:
y_block_spec = pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))
m, k = x.shape
_, n = y.shape
return pl.pallas_call(
functools.partial(matmul_kernel, nsteps=k // bk, transpose_rhs=transpose_rhs),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
y_block_spec,
],
out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
grid=(m // bm, n // bn, k // bk),
),
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
compiler_params=pltpu.TPUCompilerParams(
dimension_semantics=("parallel", "parallel", "arbitrary")),
)(x, y)
我们在 matmul
函数内部进行转置 (y = y.swapaxes(0, 1)
)。这是因为在 JIT 的 JAX 计算内部,维度排序纯粹是逻辑上的,而不是物理上的,因此重新排列维度并不意味着物理布局上的差异。但是,当我们将数组传递给 pallas_call
时,我们确实会强制执行从主维度到次维度的维度排序约束。通过在 matmul
函数内部转置 y
,我们请求 y
采用转置布局 (n, k)
,而不是通常的 (k, n)
。但是,用户仍然会以 (逻辑) (n, k)
维度排序传递数组。
注意:为了对转置进行基准测试,我们实际上希望 y
在将其传递给内核时采用物理转置布局,因此我们不会衡量重新布局时间。在包装器函数中,我们将 (逻辑) 将其转置回 (n, k)
,然后再将其传递给 matmul
,因为 matmul
预期 (逻辑) (n, k)
维度排序。
def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
mm_func, transpose_rhs: bool = False):
x = jnp.ones((m, k), dtype=dtype)
if transpose_rhs:
y = jnp.ones((n, k), dtype=dtype)
@jax.jit
def _wrapper(x, y):
y = y.swapaxes(0, 1)
return mm_func(x, y, transpose_rhs=True)
else:
y = jnp.ones((k, n), dtype=dtype)
_wrapper = mm_func
time = benchmark(_wrapper)(x, y)
print(f"----- {m} x {k} x {n} -----")
print("Matmul time: ", time)
mm_flops = matmul_flops(m, k, n) / time
print("Matmul FLOP/s: ", mm_flops)
print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
print()
print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, transpose_rhs=True)
print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, transpose_rhs=True)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time: 0.0003029372810851783
Matmul FLOP/s: 7088872126624.065
FLOP/s utilization: 3.5984%
----- 4096 x 4096 x 4096 -----
Matmul time: 0.012017967159627005
Matmul FLOP/s: 11436123235026.848
FLOP/s utilization: 5.8051%
----- 8192 x 8192 x 8192 -----
Matmul time: 0.09500920018996112
Matmul FLOP/s: 11572685861765.383
FLOP/s utilization: 5.8745%
================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time: 0.00012131539988331496
Matmul FLOP/s: 17701657415839.363
FLOP/s utilization: 8.9856%
----- 4096 x 4096 x 4096 -----
Matmul time: 0.0008790623804088682
Matmul FLOP/s: 156347213275211.03
FLOP/s utilization: 79.3641%
----- 8192 x 8192 x 8192 -----
Matmul time: 0.006107717020204291
Matmul FLOP/s: 180020067095253.78
FLOP/s utilization: 91.3807%
看看我们如何获得相同的利用率,尽管有额外的转置!
融合的激活函数#
融合激活函数也非常常见。这确保我们不会在高效的、计算受限的矩阵乘法内核之后使用缓慢的、内存受限的激活内核。
def matmul_kernel(
x_ref, y_ref, z_ref, acc_ref, *, nsteps, transpose_rhs, activation
):
@pl.when(pl.program_id(2) == 0)
def _():
acc_ref[...] = jnp.zeros_like(acc_ref)
if transpose_rhs:
dims = ((1,), (1,)), ((), ())
else:
dims = ((1,), (0,)), ((), ())
acc_ref[...] += jax.lax.dot_general(
x_ref[...],
y_ref[...],
dims,
preferred_element_type=jnp.float32,
)
@pl.when(pl.program_id(2) == nsteps - 1)
def _():
z_ref[...] = activation(acc_ref[...]).astype(z_ref.dtype)
@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'activation'])
def matmul(
x: jax.Array,
y: jax.Array,
*,
bm: int = 128,
bk: int = 128,
bn: int = 128,
transpose_rhs: bool = False,
activation: Callable[[jax.Array], jax.Array] = lambda x: x,
):
if transpose_rhs:
y = y.swapaxes(0, 1)
y_block_spec = pl.BlockSpec((bn, bk), lambda i, j, k: (j, k))
else:
y_block_spec = pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))
m, k = x.shape
_, n = y.shape
return pl.pallas_call(
functools.partial(
matmul_kernel,
nsteps=k // bk,
transpose_rhs=transpose_rhs,
activation=activation,
),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
y_block_spec,
],
out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
grid=(m // bm, n // bn, k // bk),
),
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
compiler_params=pltpu.TPUCompilerParams(
dimension_semantics=("parallel", "parallel", "arbitrary")),
)(x, y)
def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
mm_func, transpose_rhs: bool = False,
activation = lambda x: x):
x = jnp.ones((m, k), dtype=dtype)
if transpose_rhs:
y = jnp.ones((n, k), dtype=dtype)
@jax.jit
def _wrapper(x, y):
y = y.swapaxes(0, 1)
return mm_func(x, y, transpose_rhs=True, activation=activation)
else:
y = jnp.ones((k, n), dtype=dtype)
_wrapper = functools.partial(mm_func, activation=activation)
time = benchmark(_wrapper)(x, y)
print(f"----- {m} x {k} x {n} -----")
print("Matmul time: ", time)
mm_flops = matmul_flops(m, k, n) / time
print("Matmul FLOP/s: ", mm_flops)
print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
print()
activation = jax.nn.relu
print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, activation=activation)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, activation=activation)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, activation=activation)
print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, activation=activation)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, activation=activation)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, activation=activation)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time: 0.00030103540048003196
Matmul FLOP/s: 7133658182976.541
FLOP/s utilization: 3.6211%
----- 4096 x 4096 x 4096 -----
Matmul time: 0.011807117109419778
Matmul FLOP/s: 11640348122095.826
FLOP/s utilization: 5.9088%
----- 8192 x 8192 x 8192 -----
Matmul time: 0.09181861146935262
Matmul FLOP/s: 11974823079773.941
FLOP/s utilization: 6.0786%
================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time: 0.00012622540001757442
Matmul FLOP/s: 17013086492108.6
FLOP/s utilization: 8.6361%
----- 4096 x 4096 x 4096 -----
Matmul time: 0.000896632740041241
Matmul FLOP/s: 153283442968721.44
FLOP/s utilization: 77.8089%
----- 8192 x 8192 x 8192 -----
Matmul time: 0.006130605939542875
Matmul FLOP/s: 179347953304919.88
FLOP/s utilization: 91.0396%
额外的融合激活几乎不会影响我们的利用率!
结论#
在本指南中,我们介绍了如何在 TPU 上使用 Pallas 编写高效的矩阵乘法。我们讨论了阻塞矩阵乘法和流水线,如何分析 TPU 矩阵乘法的性能,以及如何编写高效的 bf16
矩阵乘法。最后,我们对矩阵乘法进行模板化,以支持融合转置和融合激活函数。
留给读者的练习
添加对输入融合的支持。有时我们希望将操作融合到矩阵乘法的输入中。尝试对矩阵乘法进行更多模板化以支持此功能。
添加对
int8
矩阵乘法的支持。TPU v5 支持本地int8
矩阵乘法,其 FLOPs 是bf16
的两倍。尝试添加对它的支持,看看可以实现什么样的利用率。添加
matmul
函数的反向传递支持。你可以使用jax.custom_vjp
来完成此操作。