shmap
(shard_map
) 用于简单的逐设备代码#
sholto@, sharadmv@, jekbradbury@, zhangqiaorjc@, mattjj@
2023 年 1 月
这是提出 shard_map
的设计文档。您可能需要查看最新的用户文档。
动机#
JAX 支持两种用于多设备编程的思想流派
编译器,接管! 让编译器自动将批量数组函数分区到设备上。
让我写出我想要表达的,该死的! 给我逐设备代码和显式的通信集合。
我们需要为两者都提供强大的 API,而且它们不应该是互斥的替代方案,而是需要彼此组合。
使用 pjit
(现在只是 jit
),我们为第一种模式提供了下一代 API。但是我们还没有完全提升第二种模式。pmap
遵循第二种模式,但随着时间的推移,我们发现它存在致命的缺陷。xmap
解决了这些缺陷,但它并没有给我们提供每个设备的形状,而且它还包含了其他几个重要的想法。与此同时,对每个设备的显式集体编程的新需求也出现了,例如在高效扩展 Transformer 推理中。
我们可以使用 shmap
来提升第二种模式。shmap
是
一个简单的多设备并行 API,它允许我们使用显式集合编写每个设备的代码,其中逻辑形状与每个设备的物理缓冲区形状匹配,并且集合与跨设备通信完全对应;
一个具有精简功能和一些调整的
xmap
的专业化版本;XLA SPMD 分区器的“手动”模式的相当直接的表面化;
一个有趣的苏斯式名称,可以代表
shard_map
、shpecialized_xmap
、sholto_map
或sharad_map
。
对于 pjit
用户,shmap
是一个补充工具。它可以在 pjit
计算内部使用,以临时进入“手动集合”模式,就像编译器自动分区的逃生舱一样。这样,用户可以在大部分代码中使用 pjit
的便利性和熟悉的仅 NumPy 编程模型,同时可以在需要的地方使用 shmap
手动优化集体通信。这是两全其美!
对于 pmap
用户,shmap
是一个严格的升级。它更具表达力、性能更高,并且可以与其他 JAX API 组合,而不会使基本批处理数据并行变得更困难。
有关实际用途的更多信息,您可以跳转到何时应该使用 shmap
,何时应该使用 pjit
?。如果您想知道为什么我们需要一个新东西,或者 pmap
有什么问题,请跳转到为什么 pmap
或 xmap
不能解决这个问题?。或者继续阅读下一节,了解一些 shmap
示例和 API 规范。
所以,让我们看看 shmap
!#
TL;DR 示例(后面会有更详细的解释)#
Sho shick
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh = jax.make_mesh((4, 2), ('i', 'j'))
a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 32.).reshape(16, 32)
@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
out_specs=P('i', None))
def matmul_basic(a_block, b_block):
# a_block: f32[2, 8]
# b_block: f32[8, 32]
z_partialsum = jnp.dot(a_block, b_block)
z_block = jax.lax.psum(z_partialsum, 'j')
return z_block
c = matmul_basic(a, b) # c: f32[8, 32]
注意
与
pmap
不同,多个并行轴不需要嵌套(或axis_index_groups
);与
pmap
和 hard-xmap
不同,调用者中不需要 reshape,并且逻辑形状对应于每个设备的物理形状,与(非 hard)xmap
不同;通过使用
mesh
精确控制设备放置,与pmap
不同;与
xmap
不同,逻辑和物理只有一个轴名称集;结果是一个
jax.Array
,可以高效地传递给pjit
,与pmap
不同;与
pmap
不同,这段相同的代码可以在pjit
/jit
内部高效工作;这段代码可以立即执行,因此我们可以在中间使用
pdb
并打印值,与xmap
的当前实现不同(但根据设计,没有顺序调度的xmap
原则上也可以立即执行)。
这是另一个具有完全分片结果的 matmul 变体
@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
out_specs=P('i', 'j'))
def matmul_reduce_scatter(a_block, b_block):
# c_partialsum: f32[8/X, 32]
c_partialsum = jnp.matmul(a_block, b_block)
# c_block: f32[8/X, 32/Y]
c_block = jax.lax.psum_scatter(c_partialsum, 'j', scatter_dimension=1, tiled=True)
return c_block
c = matmul_reduce_scatter(a, b)
慢下来,从基础开始!#
数组轴上的降秩映射与保秩映射#
我们可以将 pmap
(以及 vmap
和 xmap
)视为沿轴拆分每个数组输入(例如,将 2D 矩阵解包为 1D 行),将其主体函数应用于每个部分,并将结果堆叠在一起,至少在不涉及集合时是这样
pmap(f, in_axes=[0], out_axes=0)(xs) == jnp.stack([f(x) for x in xs])
例如,如果 xs
的形状为 f32[8,5]
,则每个 x
的形状为 f32[5]
,并且如果每个 f(x)
的形状为 f32[3,7]
,则最终堆叠结果 pmap(f)(xs)
的形状为 f32[8,3,7]
。也就是说,主体函数 f
的每次应用都会以比 pmap(f)
的相应参数少一个轴的输入作为参数。我们可以说这些是降秩映射,具有输入/输出的拆分/堆叠。
f
的逻辑应用程序的数量由正在映射的输入轴的大小决定:例如,如果我们映射一个大小为 8 的输入轴,则在语义上,我们得到 8 个逻辑函数应用程序,对于 pmap,这些应用程序始终对应于 8 个物理计算它们的设备。
相反,shmap
没有这种降秩行为。相反,我们可以将其视为沿输入轴将切片(或“取消连接”)到块中,应用主体函数,然后将结果连接在一起(同样,当不涉及集合时)
devices = np.array(jax.devices()[:4])
m = Mesh(devices, ('i',)) # mesh.shape['i'] = 4
shard_map(f, m, in_specs=P('i'), out_specs=P('i'))(y)
==
jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, 4)])
回想一下,jnp.split
将其输入切片为大小相同的块,这些块具有相同的秩,因此在上面的示例中,如果 y
的形状为 f32[8,5]
,则每个 y_blk
的形状为 f32[2,5]
,并且如果每个 f(y_blk)
的形状为 f32[3,7]
,则最终连接的结果 shard_map(f, ...)(y)
的形状为 f32[12,7]
。因此,shmap
(shard_map
) 会映射其输入的分片或块。我们可以说它是一个保秩映射,具有输入/输出的取消连接/连接。
f
的逻辑应用程序的数量由网格大小决定,而不是由任何输入轴大小决定:例如,如果我们有一个总大小为 4 的网格(即,在 4 个设备上),则在语义上,我们得到 4 个逻辑函数应用程序,对应于 4 个物理计算它们的设备。
使用 in_specs
控制每个输入如何拆分(取消连接)和切片#
每个 in_specs
都使用 PartitionSpec
按名称标识一些相应的输入数组的轴与网格轴,表示如何将输入拆分(或取消连接)为应用主体函数的块。该标识确定分片大小;当输入轴被标识为网格轴时,输入会沿该逻辑轴拆分(取消连接)为数量等于相应网格轴大小的片段。(如果相应的网格轴大小不能均匀地除输入数组轴大小,则会发生错误。)如果输入的 pspec 没有提及网格轴名称,则不会在该网格轴上进行拆分。例如
devices = np.array(jax.devices())
m = Mesh(devices.reshape(4, 2), ('i', 'j'))
@partial(shard_map, mesh=m, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
print(x_block.shape)
return x_block
x1 = np.arange(12 * 12).reshape(12, 12)
y = f1(x1) # prints (3,12)
在这里,由于输入 pspec 没有提及网格轴名称 'j'
,因此没有输入数组轴在该网格轴上拆分;同样,由于输入数组的第二个轴没有被识别为(因此没有在该轴上拆分)任何网格轴,因此 f1
的应用程序沿着该轴获得了输入的完整视图。
当输入 pspec 中未提及网格轴时,我们始终可以重写为一个效率较低的程序,其中提及了所有网格轴,但调用者执行 jnp.tile
,例如
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
print(x_block.shape)
return x_block
x = np.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, mesh.axis_size['j'])) # x_ has shape (12, 24)
y = f2(x_) # prints (3,12), and f1(x) == f2(x_)
换句话说,因为每个输入 pspec 可以零次或一次提及每个网格轴名称,而不是必须恰好提及每个名称一次,我们可以说,除了其输入中内置的 jnp.split
之外,shard_map
在其输入中还内置了 jnp.tile
,至少在逻辑上是这样(尽管平铺可能不需要物理执行,具体取决于参数的物理分片布局)。使用的平铺不是唯一的;我们也可以沿着第一个轴进行平铺,并使用 pspec P(('j', 'i'), None)
。
输入端可能发生物理数据移动,因为每个设备都需要拥有相应数据的副本。
使用 out_specs
控制每个输出如何通过串联、块转置和解平铺组装#
与输入端类似,每个 out_specs
通过名称将相应的输出数组的一些轴与网格轴标识,表示如何将输出块(对于函数体的每次应用,或等效地对于每个物理设备)组装在一起以形成最终输出值。例如,在上面的 f1
和 f2
示例中,out_specs
表示我们应该通过沿着两个轴串联块结果来形成最终输出,在两种情况下都得到一个形状为 (12,24)
的数组 y
。(如果函数体的输出形状,即输出块形状,对于相应的输出 pspec 描述的串联而言秩太小,则会出错。)
当输出 pspec 中未提及网格轴名称时,它表示解平铺:当用户编写一个未提及某个网格轴名称的输出 pspec 时,他们承诺沿着该网格轴输出块是相等的,因此输出中仅使用沿着该轴的一个块(而不是沿着该网格轴将所有块串联在一起)。例如,使用与上面相同的网格
x = jnp.array([[3.]])
z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', 'j'))()
print(z) # prints the same as jnp.tile(x, (4, 2))
z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', None))()
print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))
z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P(None, None))()
print(z) # prints the same as jnp.tile(x, (1, 1)), or just x
请注意,闭包一个数组值的函数体等效于传递它作为具有 P(None, None)
的相应输入 pspec 的参数。作为另一个例子,更接近上面的其他例子
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
return jax.lax.psum(x_block, 'j')
x = np.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape) # (12,6)
请注意,结果的第二个轴大小为 6,是输入第二个轴大小的一半。在这种情况下,由于集体 psum
,通过在输出 pspec 中不提及网格轴名称 'j'
来表示的解平铺是安全的,这确保了每个输出块沿着相应的网格轴是相等的。以下是另外两个示例,其中我们更改了在输出 pspec 中提及的网格轴
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
return jax.lax.psum(x_block, 'i')
x = np.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape) # (3,12)
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
return jax.lax.psum(x_block, ('i', 'j'))
y5 = f5(x)
print(y5.shape) # (3,6)
在物理端,在输出 pspec 中不提及网格轴名称会从具有沿着该网格轴复制布局的输出设备缓冲区组装一个 Array
。
没有运行时检查输出块是否确实沿着要解平铺的网格轴相等,或者等效地,相应的物理缓冲区是否具有相等的值,因此可以解释为单个逻辑数组的复制布局。但是,我们可以提供一个静态检查机制,该机制会在所有可能不正确的程序上引发错误。
因为 out_specs
可以零次或一次提及网格轴名称,并且因为它们可以按任何顺序提及,我们可以说,除了其输出中内置的 jnp.concatenate
之外,shard_map
在其输出中还内置了解平铺和块转置。
无论输出 pspec 如何,输出端都不可能发生物理数据移动。相反,out_specs
仅编码如何将块输出组装成 Array
,或者物理上如何将跨设备的缓冲区解释为单个逻辑 Array
的物理布局。
API 规范#
from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]
def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs
) -> Callable:
...
其中
mesh
编码排列在数组中并具有关联轴名称的设备,就像它对xmap
和sharding.NamedSharding
所做的那样;in_specs
和out_specs
是PartitionSpec
,可以仿射提及来自mesh
的轴名称(而不是像xmap
中那样单独的逻辑名称),以分别表达输入和输出的切片/解串联和串联(而不是像pmap
和xmap
所做的那样解堆叠和堆叠),其中未提及的名称分别对应于复制和解平铺(断言已复制,因此给我一份副本);传递给
f
的参数的形状与传递给shard_map
-of-f
的参数的形状具有相同的秩(与pmap
和xmap
不同,后者的秩会减小),并且传递给f
的参数的形状是从shard_map
-of-f
的相应参数的形状shape
和相应的PartitionSpec
规范计算得出的,大致为tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))
;f
的主体可以使用来自mesh
的名称应用集合。
shmap
默认是急切的,这意味着我们逐个分派计算基元,以便用户可以在完全复制的值上采用 Python 控制流,并使用交互式 pdb
调试来打印任何值。要分阶段执行并端到端编译一个 shmap
映射的函数,只需在其周围放置一个 jit
。一个结果是 shmap
没有像 xmap
和 pmap
当前那样自己的分派和编译路径;它只是 jit
路径。
当它由例如封闭的 jit
分阶段执行时,将 shmap
降级为 StableHLO 是微不足道的:它只涉及在输入端切换到“手动 SPMD 模式”,并在输出端切换回来。(我们目前不打算支持部分手动、部分自动模式。)
与效果的交互与 pmap
相同。
与自动微分的交互也像 pmap
一样(而不是尝试 xmap
所做的新语义,对应于具有未映射的中间值,因此 grad
的 reduce_axes
,以及使 psum
转置为 pbroadcast
而不是 psum
)。但是,它因此继承了 pmap
的一个未解决的问题:在某些情况下,与其将 psum
转置为 psum
,从而执行与前向传递 psum
相对应的后向传递 psum
,不如将后向传递 psum
移动到后向传递中的其他位置,从而利用线性。许多高级 pmap
用户通过使用 custom_vjp
来实现 psum_idrev
和 id_psumrev
函数来解决此挑战,但是由于很容易意外地使它们不平衡,因此该技术是一种危险的做法。我们有一些关于如何以更安全的方式提供此功能的方法。
何时应该使用 shmap
,何时应该使用 pjit
?#
一种理念是:使用 jit==pjit
编写程序几乎总是更简单——但是如果程序的给定部分被编译器优化的程度低于它应该达到的程度,则可以使用 shmap
!
一个真实的示例#
以下是 shmap
在具有 2D 权重收集模式的 Transformer 层传递中的样子(论文,第 5 页的 3.2.3 节)
def matmul_2D_wg_manual(xnorm, q_wi, layer):
'''Calls a custom manual implementation of matmul_reducescatter'''
# [batch, maxlen, embed.X] @ [heads.YZ, embed.X, q_wi_per_head]
# -> (matmul)
# -> [batch, maxlen, heads.YZ, q_wi_per_head]{x unreduced}
# -> (reducescatter over x into X heads, B batches)
# -> [batch, maxlen, heads.YZX, q_wi_per_head]
with jax.named_scope('q_wi'):
xnorm = intermediate_dtype(xnorm)
q_wi = matmul_reducescatter(
'bte,hed->bthd',
xnorm,
params.q_wi,
scatter_dimension=(0, 2),
axis_name='i',
layer=layer)
return q_wi
import partitioning.logical_to_physical as l2phys
def pjit_transformer_layer(
hparams: HParams, layer: int, params: weights.Layer, sin: jnp.ndarray,
cos: jnp.ndarray, kv_caches: Sequence[attention.KVCache],
x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Forward pass through a single layer, returning output, K, V."""
def my_layer(t, axis=0):
"""Gets the parameters corresponding to a given layer."""
return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)
# 2D: [batch.Z, time, embed.XY]
x = _with_sharding_constraint(
x, ('residual_batch', 'residual_time', 'residual_embed'))
xnorm = _layernorm(x)
# 2D: [batch, time, embed.X]
xnorm = _with_sharding_constraint(
xnorm, ('post_norm_batch', 'time', 'post_norm_embed'))
# jump into manual mode where you want to optimise
if manual:
q_wi = shard_map(matmul_2D_wg_manual, mesh
in_specs=(l2phys('post_norm_batch', 'time', 'post_norm_embed'),
l2phys('layers', 'heads', 'embed', 'q_wi_per_head')),
out_specs=l2phys('post_norm_batch', 'time', 'heads', 'q_wi_per_head'))(xnorm, q_wi, layer)
else:
q_wi = jnp.einsum('bte,hed->bthd', xnorm, my_layer(params.q_wi))
# 2D: [batch, time, heads.YZX, None]
q_wi = _with_sharding_constraint(q_wi,
('post_norm_batch', 'time', 'heads', 'qkv'))
q = q_wi[:, :, :, :hparams.qkv]
q = _rope(sin, cos, q)
# unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
# swiGLU with full d_ff dimension, rather than 2/3 scaled
wi0 = q_wi[:, :, :, hparams.qkv:hparams.qkv + (hparams.ff // hparams.heads)]
wi1 = q_wi[:, :, :, hparams.qkv + (hparams.ff // hparams.heads):]
kv = jnp.einsum('bte,ezd->btzd', xnorm, my_layer(params.kv))
k = kv[:, :, 0, :hparams.qkv]
v = kv[:, :, 0, hparams.qkv:]
k = _rope(sin, cos, k)
y_att = jnp.bfloat16(attention.attend(q, k, v, kv_caches, layer))
y_mlp = special2.swish2(wi0) * wi1
# 2D: [batch, time, heads.YZX, None]
y_mlp = _with_sharding_constraint(y_mlp,
('post_norm_batch', 'time', 'heads', None))
y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)
# do the second half of the mlp and the self-attn projection in parallel
y_out = jnp.einsum('bthd,hde->bte', y_fused, my_layer(params.o_wo))
# 2D: [batch.Z, time, embed.XY]
y_out = _with_sharding_constraint(
y_out, ('residual_batch', 'residual_time', 'residual_embed'))
z = y_out + x
z = _with_sharding_constraint(
z, ('residual_batch', 'residual_time', 'residual_embed'))
return z, k, v
在下面的配置文件中,第一个和第二个 matmul 都被手动降级的版本所取代,其中计算(融合)与通信(ppermute)完全重叠!一个有趣的提示表明我们正在使用延迟优化的变体,即 ppmerute 像素是抖动的——因为同时有两个使用相反的 ICI 轴的重叠 ppermute!
全对全重叠要困难得多,因此被搁置了。
为什么 pmap
或 xmap
不能解决这个问题?#
pmap
是我们的第一个多设备并行 API。它遵循按设备代码和显式集合的模式。但是它存在一些重大缺陷,使其不适合当今的程序
映射多个轴需要嵌套的
pmap
。 嵌套的pmap
不仅编写起来很麻烦,而且还使得难以控制(甚至预测)数据和计算的设备放置,并且难以保留数据分片(请参见以下两点)。当今的程序需要多个并行轴。无法控制设备放置。 特别是在存在多个并行轴的情况下,程序员需要控制这些轴如何与硬件资源及其通信拓扑结构对齐。但是,(嵌套的)
pmap
无法控制映射的程序实例如何放置在硬件上;只有一个自动的设备顺序,用户无法控制。(Gopher 使用axis_index_groups
和单个非嵌套的pmap
本质上是一种 hack,通过将多个并行轴扁平化为一个轴来解决这个问题。)jit
/pjit
的可组合性。jit
-of-pmap
是一种性能陷阱,嵌套pmap
也是如此,例如scan
-of-pmap
,因为当从内部pmap
返回时,分片不会被保留。为了保留分片,我们需要对 jaxprs 进行模式匹配,以确保我们正在使用完美嵌套的 pmaps,或者只是在jit
内使用 pmap。此外,pjit
在这里也帮不上忙,因为pmap
目标是 XLA 副本,而pjit
目标是 XLA SPMD 分区器,将这两者组合起来很困难。jax.Array
兼容性(因此也是pjit
兼容性)。 因为pmap
输出的分片无法表示为Shardings
/OpShardings
,由于pmap
的堆叠而非连接语义,pmap
计算的输出当前无法传递给pjit
计算,而无需先返回到主机(或调度一个重塑计算)。多控制器语义(因此也是
pjit
兼容性)。 多控制器pmap
将值跨控制器连接起来,这种方式效果很好,但与单控制器pmap
的堆叠语义不同。更实际的是,它阻止了使用非完全可寻址的jax.Array
输入和输出,就像我们在多控制器pjit
中使用的那样。Eager 模式。 我们没有将
pmap
设计为 eager 优先,虽然我们最终(在 4 年多之后!)通过disable_jit()
添加了 eager 操作,但pmap
具有融合到其中的jit
意味着它有自己的编译和调度路径(实际上是两条调度路径:一条在 Python 中用于处理Tracer
,另一条在 C++ 中用于在原始Array
输入上获得性能!),这是一个沉重的实现负担。需要在调用者中进行重塑。 一个典型的在 8 个设备上使用
pmap
的用例可能是从大小为 128 的批次轴开始,将其重塑为分割成两个大小分别为 (8, 16) 的轴,然后对第一个轴进行pmap
操作。这些重塑很笨拙,并且编译器通常将它们解释为复制而不是视图 — 增加了内存和时间使用。
当只进行批次数据并行时,这些缺点并不算太糟糕。但是当涉及到更多并行时,pmap
就无法胜任了!
xmap
作为 pmap
的下一代演进铺平了道路,并解决了(几乎)所有这些问题。shmap
紧随 xmap
的脚步,并以基本相同的方式解决了这些问题;事实上,shmap
就像 xmap
的一个专门子集(有些人称之为“硬 xmap
” 子集),进行了一些调整。
对于初始原型,我们选择将 shmap
作为与 xmap
分开的原始类型来实现,因为限制其支持的功能集可以更容易地专注于核心功能。例如,shmap
不允许未映射的中间体,这使得不必担心命名轴和自动微分之间的相互作用。此外,不必推理所有特征对之间的相互作用,使得更容易添加超出今天在 xmap
中实现的功能,例如对 eager 模式的支持。
shmap
和 xmap
都共享大部分的降低代码。我们将来可以考虑将两者合并,甚至只关注 shmap
,这取决于用法将如何演变。