shmap
(shard_map
) 用于简单的每设备代码#
sholto@, sharadmv@, jekbradbury@, zhangqiaorjc@, mattjj@
2023 年 1 月
动机#
JAX 支持两种关于多设备编程的想法
编译器,你来掌舵!让编译器自动将批量数组函数分配到各个设备上。
让我写出我的意思,该死的!给我每设备代码和显式通信集合。
我们需要针对这两者提供强大的 API,并且它们不应该是相互排斥的替代方案,而是需要相互组合。
有了 pjit
(现在只是 jit
),我们有了 一种下一代 API,用于第一种想法。但是我们还没有完全提升第二种想法。 pmap
遵循第二种想法,但随着时间的推移,我们发现它存在 致命缺陷。 xmap
解决了这些缺陷,但它没有完全给我们每设备形状,并且它还包括其他一些重大想法。与此同时,出现了对每设备显式集合编程的新需求,例如在 Efficiently Scaling Transformer Inference 中。
我们可以使用 shmap
提升第二所学校的水平。 shmap
是
一个简单的多设备并行 API,它允许我们编写具有显式集合的每个设备代码,其中逻辑形状匹配每个设备的物理缓冲区形状,集合完全对应于跨设备通信;
xmap
的一个特化,具有缩减的功能和一些调整;XLA SPMD Partitioner 的“手动”模式的相当直接的呈现;
一个朗朗上口的苏斯博士风格的名字,可以代表
shard_map
、shpecialized_xmap
、sholto_map
或sharad_map
。
**对于 pjit
用户**, shmap
是一种补充工具。它可以在 pjit
计算中使用,临时进入“手动集合”模式,就像编译器自动分区的一个逃生舱口。这样,用户可以获得 pjit
的便利性和熟悉的 NumPy 编程模型,用于其大部分代码,以及在需要时使用 shmap
手动优化集合通信的能力。这是两全其美!
**对于 pmap
用户**, shmap
是一个严格的升级。它更具表现力、更高效,并且可以与其他 JAX API 组合使用,而不会使基本批处理数据并行变得更加困难。
有关实际应用的更多信息,您可以跳转到 何时使用 shmap
以及何时使用 pjit
?。如果您想知道我们为什么需要一个新东西,或者 pmap
的问题是什么,请跳转到 pmap
或 xmap
为什么不能解决这个问题?。或者继续阅读下一节,了解一些 shmap
示例和 API 规范。
所以,让我们来看看 shmap
!#
TL;DR 示例(后面将有更详细的说明)#
Sho shick
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('i', 'j'))
a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 32.).reshape(16, 32)
@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
out_specs=P('i', None))
def matmul_basic(a_block, b_block):
# a_block: f32[2, 8]
# b_block: f32[8, 32]
z_partialsum = jnp.dot(a_block, b_block)
z_block = jax.lax.psum(z_partialsum, 'j')
return z_block
c = matmul_basic(a, b) # c: f32[8, 32]
注意
与
pmap
不同,不需要嵌套(或axis_index_groups
)来实现多个并行轴;与
pmap
和硬xmap
不同,调用者不需要进行重塑,并且逻辑形状对应于每个设备的物理形状,与(非硬)xmap
不同;通过使用
mesh
精确控制设备放置,与pmap
不同;与
xmap
不同,逻辑和物理只有一个轴名集;结果是一个
jax.Array
,可以有效地传递给pjit
,与pmap
不同;此代码在
pjit
/jit
内部也能有效地工作,与pmap
不同;此代码会积极地工作,因此我们可以在中间使用
pdb
并打印值,与xmap
的当前实现不同(虽然根据设计,没有顺序计划的xmap
原则上也可以积极地工作)。
这是另一个具有完全分片结果的矩阵乘法变体
@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
out_specs=P('i', 'j'))
def matmul_reduce_scatter(a_block, b_block):
# c_partialsum: f32[8/X, 32]
c_partialsum = jnp.matmul(a_block, b_block)
# c_block: f32[8/X, 32/Y]
c_block = jax.lax.psum_scatter(c_partialsum, 'j', scatter_dimension=1, tiled=True)
return c_block
c = matmul_reduce_scatter(a, b)
慢下来,从基础开始!#
对数组轴进行降秩映射与保留秩映射#
我们可以将 pmap
(以及 vmap
和 xmap
)视为将每个数组输入沿着某个轴展开(例如,将一个二维矩阵展开成其一维行),将主体函数应用于每个部分,并将结果堆叠在一起,至少在没有集合的情况下。
pmap(f, in_axes=[0], out_axes=0)(xs) == jnp.stack([f(x) for x in xs])
例如,如果 xs
的形状为 f32[8,5]
,那么每个 x
的形状为 f32[5]
,如果每个 f(x)
的形状为 f32[3,7]
,那么最终堆叠的结果 pmap(f)(xs)
的形状为 f32[8,3,7]
。也就是说,主体函数 f
的每次应用都以比 pmap(f)
的相应参数少一个轴的输入作为参数。我们可以说这些是具有输入/输出展开/堆叠的降秩映射。
f
的逻辑应用次数由正在映射的输入轴的大小决定:例如,如果我们映射一个大小为 8 的输入轴,在语义上我们将获得 8 次逻辑函数应用,对于 pmap 来说,这始终对应于 8 个设备物理地计算它们。
相反,shmap
没有这种降秩行为。相反,我们可以把它想象成沿着输入轴切片(或“取消连接”)成块,应用主体函数,并将结果连接在一起(同样,在没有集合的情况下)。
devices = np.array(jax.devices()[:4])
m = Mesh(devices, ('i',)) # mesh.shape['i'] = 4
shard_map(f, m, in_specs=P('i'), out_specs=P('i'))(y)
==
jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, 4)])
回想一下,jnp.split
将其输入切片成大小相等、秩相同的块,因此,如果在上述示例中 y
的形状为 f32[8,5]
,那么每个 y_blk
的形状为 f32[2,5]
,如果每个 f(y_blk)
的形状为 f32[3,7]
,那么最终连接的结果 shard_map(f, ...)(y)
的形状为 f32[12,7]
。所以 shmap
(shard_map
) 映射到其输入的分片或块。我们可以说它是一个具有输入/输出取消连接/连接的保留秩映射。
f
的逻辑应用次数由网格大小决定,而不是由任何输入轴大小决定:例如,如果我们有一个总大小为 4 的网格(即,超过 4 个设备),那么在语义上我们将获得 4 次逻辑函数应用,对应于 4 个设备物理地计算它们。
使用 in_specs
控制每个输入的拆分(取消连接)和平铺#
每个 in_specs
使用 PartitionSpec
将相应输入数组的某些轴与网格轴通过名称标识,表示如何将该输入拆分成(或取消连接)将应用主体函数的块。该标识确定分片大小;当输入轴与网格轴标识时,输入将沿着该逻辑轴拆分(取消连接)成与相应网格轴大小相等的多个部分。(如果相应网格轴大小不能整除输入数组轴大小,则会出错。)如果输入的 pspec 未提及网格轴名称,则不会在该网格轴上进行拆分。例如
devices = np.array(jax.devices())
m = Mesh(devices.reshape(4, 2), ('i', 'j'))
@partial(shard_map, mesh=m, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
print(x_block.shape)
return x_block
x1 = np.arange(12 * 12).reshape(12, 12)
y = f1(x1) # prints (3,12)
这里,因为输入 pspec 未提及网格轴名称 'j'
,所以没有任何输入数组轴在该网格轴上进行拆分;类似地,因为输入数组的第二个轴没有与任何网格轴标识(因此没有在该轴上进行拆分),所以 f1
的应用获得了该轴上的完整输入视图。
当网格轴在输入 pspec 中未被提及时,我们始终可以重写成效率较低的程序,其中所有网格轴都被提及,但调用者执行 jnp.tile
,例如
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
print(x_block.shape)
return x_block
x = np.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, mesh.axis_size['j'])) # x_ has shape (12, 24)
y = f2(x_) # prints (3,12), and f1(x) == f2(x_)
换句话说,因为每个输入 pspec 可以提及每个网格轴名称零次或一次,而不是必须恰好提及一次,所以我们可以说除了其输入中内置的 jnp.split
之外,shard_map
还将其输入中内置了 jnp.tile
,至少在逻辑上(尽管平铺可能不需要物理地执行,具体取决于参数的物理分片布局)。要使用的平铺不是唯一的;我们也可以沿着第一个轴进行平铺,并使用 pspec P(('j', 'i'), None)
。
可以在输入上进行物理数据移动,因为每个设备都需要有相应数据的副本。
使用 out_specs
控制每个输出的组装方式,包括连接、块转置和取消平铺#
与输入端类似,每个 out_specs
通过名称将相应输出数组的某些轴与网格轴标识,表示如何将输出块(每个主体函数应用一个,或者等效地每个物理设备一个)组装在一起以形成最终的输出值。例如,在上面的 f1
和 f2
示例中,out_specs
指示我们应该通过沿着两个轴连接块结果来形成最终输出,这两种情况都会导致一个形状为 (12,24)
的数组 y
。(如果主体函数的输出形状,即输出块形状,的秩小于对应输出 pspec 所描述的连接,则会出错。)
当输出 pspec 中未提及网格轴名称时,表示不平铺:当用户写入不包含网格轴名称之一的输出 pspec 时,他们承诺输出块沿该网格轴相等,因此输出中仅使用沿该轴的一个块(而不是将所有块沿该网格轴拼接在一起)。例如,使用上面相同的网格
x = jnp.array([[3.]])
z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', 'j'))()
print(z) # prints the same as jnp.tile(x, (4, 2))
z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', None))()
print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))
z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P(None, None))()
print(z) # prints the same as jnp.tile(x, (1, 1)), or just x
注意,闭合在数组值上的体函数等效于将它作为增量传递,并带有相应的输入 pspec P(None, None)
。另一个例子,更接近上面其他例子
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
return jax.lax.psum(x_block, 'j')
x = np.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape) # (12,6)
注意,结果的第二个轴大小为 6,是输入第二个轴大小的一半。在这种情况下,由于集体 psum
,在输出 pspec 中未提及网格轴名称 'j'
表示的非平铺是安全的,它确保每个输出块沿相应的网格轴相等。以下还有两个例子,我们在其中改变了输出 pspec 中提到的网格轴
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
return jax.lax.psum(x_block, 'i')
x = np.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape) # (3,12)
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
return jax.lax.psum(x_block, ('i', 'j'))
y5 = f5(x)
print(y5.shape) # (3,6)
在物理方面,在输出 pspec 中不提及网格轴名称会从输出设备缓冲区中组装一个 Array
,其沿该网格轴具有复制布局。
没有运行时检查来验证输出块是否实际上沿要非平铺的网格轴相等,或者等效地验证相应的物理缓冲区是否具有相等的值,因此可以解释为单个逻辑数组的复制布局。但我们可以提供一个静态检查机制,它会对所有可能不正确的程序引发错误。
因为 out_specs
可以提及网格轴名称零次或一次,而且它们可以按任何顺序提及,所以我们可以说,除了内置到其输出中的 jnp.concatenate
之外,shard_map
还将非平铺和块转置内置到其输出中。
无论输出 pspec 如何,物理数据移动在输出上都是不可能的。相反,out_specs
只是编码如何将块输出组装成 Array
,或者在物理上如何解释跨设备的缓冲区作为单个逻辑 Array
的物理布局。
API 规范#
from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]
def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs
) -> Callable:
...
其中
mesh
编码以数组形式排列的设备以及关联的轴名称,就像它对xmap
和sharding.NamedSharding
一样;in_specs
和out_specs
是PartitionSpec
,它们可以仿射地提及来自mesh
的轴名称(不像xmap
中的单独逻辑名称)以分别表示输入和输出的切片/非拼接和拼接(不像pmap
和xmap
所做的那样进行解堆叠和堆叠),未提及的名称对应于复制和非平铺(断言复制,因此给我一份副本),分别;传递给
f
的参数的形状与传递给shard_map
-of-f
的参数的形状等级相同(与pmap
和xmap
不同,其中等级被降低),并且传递给f
的参数的形状是根据传递给shard_map
-of-f
的相应参数的形状shape
和相应的PartitionSpec
规范计算的,大致为tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))
;f
的主体可以使用来自mesh
的名称来应用集体。
shmap
默认情况下是急切的,这意味着我们将逐个基元地分派计算,这样用户就可以对完全复制的值使用 Python 控制流,以及交互式 pdb
调试来打印任何值。要将 shmap
ed 函数分阶段推出并进行端到端编译,只需在它周围放置一个 jit
。结果是,shmap
没有像 xmap
和 pmap
那样具有自己的分派和编译路径;它只是 jit
路径。
当它被例如包含的 jit
分阶段推出时,shmap
向 StableHLO 的降低是微不足道的:它只涉及在输入上切换到“手动 SPMD 模式”,并在输出上切换回来。(我们目前不打算支持部分手动部分自动模式。)
与效果的交互与 pmap
相同。
与自动微分的交互也与 pmap
一样(而不是尝试 xmap
所做的新的语义,对应于具有未映射的中介,因此 grad
的 reduce_axes
以及使 psum
转置为 pbroadcast
而不是 psum
)。但这从 pmap
继承了一个未解决的问题:在某些情况下,而不是将 psum
转置为 psum
,从而执行与正向传递 psum
相对应的反向传递 psum
,将反向传递 psum
移动到反向传递中的其他位置,利用线性,可能会更有利。许多高级 pmap
用户通过使用 custom_vjp
来实现 psum_idrev
和 id_psumrev
函数来解决这个挑战,但是由于很容易不小心使它们不平衡,因此该技术是一种大炮。我们有一些想法可以更安全地提供此功能。
你什么时候应该使用 shmap
,什么时候应该使用 pjit
?#
一种哲学是:用 jit==pjit
编写程序几乎总是更简单——但如果程序的某个部分比可能的情况优化得更少,就进入 shmap
!
一个现实的 Transformer 例子#
事实上,我们可以使用 shmap
和 30 行 Python 代码来实现最近在 XLA 中引入的“集体矩阵乘法”算法的简单版本。该算法的基本思想可以用一个简单的例子来理解。
假设我们想要计算 C = A @ B
,其中 A
由一维网格在第 0 维上分片,而 B
和 C
是复制的。
M, K, N = 4096, 2048, 1024
A = jnp.arange(np.prod((M, K))).reshape((M, K))
B = jnp.arange(np.prod((K, N))).reshape((K, N))
mesh = Mesh(np.array(jax.devices()), axis_names=('i'))
A_x = jax.device_put(A, NamedSharding(mesh, P('i', None)))
@jax.jit
def f(lhs, rhs):
return lhs @ rhs
C = f(A_x, B)
配置文件显示了在矩阵乘法开始之前在 8 个设备上进行阻塞式全收集。这并不理想,因为 A
在非收缩维度上被分片,并且 A
的每个分片可以独立地与 B
相乘,这种分块计算可以与从另一个设备获取 A
的下一个分片重叠。
可以使用 shmap
和显式集体来实现这种重叠。
def collective_matmul_allgather_lhs_non_contracting(lhs, rhs):
# lhs is the looped operand; rhs is the local operand
axis_size = jax.lax.psum(1, axis_name='i')
axis_index = jax.lax.axis_index(axis_name='i')
chunk_size = lhs.shape[0]
def f(i, carrys):
accum, lhs = carrys
# matmul for a chunk
update = lhs @ rhs
# circular shift to the left
lhs = jax.lax.ppermute(
lhs,
axis_name='i',
perm=[(j, (j - 1) % axis_size) for j in range(axis_size)]
)
# device 0 computes chunks 0, 1, ...
# device 1 computes chunks 1, 2, ...
update_index = (((axis_index + i) % axis_size) * chunk_size, 0)
accum = jax.lax.dynamic_update_slice(accum, update, update_index)
return accum, lhs
accum = jnp.zeros((lhs.shape[0] * axis_size, rhs.shape[1]), dtype=lhs.dtype)
# fori_loop cause a crash: hlo_sharding.cc:817 Check failed: !IsManual()
# accum, lhs = jax.lax.fori_loop(0, axis_size - 1, f, (accum, lhs))
for i in range(0, axis_size - 1):
accum, lhs = f(i, (accum, lhs))
# compute the last chunk, without the ppermute
update = lhs @ rhs
i = axis_size - 1
update_index = (((axis_index + i) % axis_size) * chunk_size, 0)
accum = jax.lax.dynamic_update_slice(accum, update, update_index)
return accum
jit_sharded_f = jax.jit(shard_map(
collective_matmul_allgather_lhs_non_contracting, mesh,
in_specs=(P('i', None), P()), out_specs=P()))
C = jit_sharded_f(A_x, B)
配置文件显示全收集消失了,取而代之的是与异步集体置换重叠的矩阵乘法。此配置文件与集体矩阵乘法论文结果非常吻合。
这种集体矩阵乘法技术可用于加速 Transformer 层中的前馈块。这通常包括两个矩阵乘法,然后是 ReduceScatter
(以解决来自并行化矩阵乘法的部分和),并在前面是 AllGather
(以收集沿某些轴的分片维度,并允许部分和计算)。总的来说,来自一层中的 ReduceScatter
和下一层中的 AllGather
相当于一个 AllReduce
。
在典型的配置文件中,两个矩阵乘法将后跟一个 AllReduce
,并且它们不会重叠。集体矩阵乘法可以用来实现重叠,但难以触发,有最小切片大小,并且尚未涵盖所有拓扑、张量形状和集体矩阵乘法的变体(即延迟和吞吐量优化的变体)。在最近的一篇论文中,我们发现通过在 shmap
风格中手动实现集体矩阵乘法的变体,在许多情况下可以获得约 40% 的收益。
但这并不总是更复杂!我们预计这将是思考流水线计算的一种更自然的方式,并计划尽快进行一些演示!
另一个现实的例子#
以下是如何在 Transformer 层传递中使用具有二维权重收集模式的 shmap
(论文,第 3.2.3 节,第 5 页)
def matmul_2D_wg_manual(xnorm, q_wi, layer):
'''Calls a custom manual implementation of matmul_reducescatter'''
# [batch, maxlen, embed.X] @ [heads.YZ, embed.X, q_wi_per_head]
# -> (matmul)
# -> [batch, maxlen, heads.YZ, q_wi_per_head]{x unreduced}
# -> (reducescatter over x into X heads, B batches)
# -> [batch, maxlen, heads.YZX, q_wi_per_head]
with jax.named_scope('q_wi'):
xnorm = intermediate_dtype(xnorm)
q_wi = matmul_reducescatter(
'bte,hed->bthd',
xnorm,
params.q_wi,
scatter_dimension=(0, 2),
axis_name='i',
layer=layer)
return q_wi
import partitioning.logical_to_physical as l2phys
def pjit_transformer_layer(
hparams: HParams, layer: int, params: weights.Layer, sin: jnp.ndarray,
cos: jnp.ndarray, kv_caches: Sequence[attention.KVCache],
x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Forward pass through a single layer, returning output, K, V."""
def my_layer(t, axis=0):
"""Gets the parameters corresponding to a given layer."""
return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)
# 2D: [batch.Z, time, embed.XY]
x = _with_sharding_constraint(
x, ('residual_batch', 'residual_time', 'residual_embed'))
xnorm = _layernorm(x)
# 2D: [batch, time, embed.X]
xnorm = _with_sharding_constraint(
xnorm, ('post_norm_batch', 'time', 'post_norm_embed'))
# jump into manual mode where you want to optimise
if manual:
q_wi = shard_map(matmul_2D_wg_manual, mesh
in_specs=(l2phys('post_norm_batch', 'time', 'post_norm_embed'),
l2phys('layers', 'heads', 'embed', 'q_wi_per_head')),
out_specs=l2phys('post_norm_batch', 'time', 'heads', 'q_wi_per_head'))(xnorm, q_wi, layer)
else:
q_wi = jnp.einsum('bte,hed->bthd', xnorm, my_layer(params.q_wi))
# 2D: [batch, time, heads.YZX, None]
q_wi = _with_sharding_constraint(q_wi,
('post_norm_batch', 'time', 'heads', 'qkv'))
q = q_wi[:, :, :, :hparams.qkv]
q = _rope(sin, cos, q)
# unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
# swiGLU with full d_ff dimension, rather than 2/3 scaled
wi0 = q_wi[:, :, :, hparams.qkv:hparams.qkv + (hparams.ff // hparams.heads)]
wi1 = q_wi[:, :, :, hparams.qkv + (hparams.ff // hparams.heads):]
kv = jnp.einsum('bte,ezd->btzd', xnorm, my_layer(params.kv))
k = kv[:, :, 0, :hparams.qkv]
v = kv[:, :, 0, hparams.qkv:]
k = _rope(sin, cos, k)
y_att = jnp.bfloat16(attention.attend(q, k, v, kv_caches, layer))
y_mlp = special2.swish2(wi0) * wi1
# 2D: [batch, time, heads.YZX, None]
y_mlp = _with_sharding_constraint(y_mlp,
('post_norm_batch', 'time', 'heads', None))
y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)
# do the second half of the mlp and the self-attn projection in parallel
y_out = jnp.einsum('bthd,hde->bte', y_fused, my_layer(params.o_wo))
# 2D: [batch.Z, time, embed.XY]
y_out = _with_sharding_constraint(
y_out, ('residual_batch', 'residual_time', 'residual_embed'))
z = y_out + x
z = _with_sharding_constraint(
z, ('residual_batch', 'residual_time', 'residual_embed'))
return z, k, v
在下方的配置文件中,第一个和第二个矩阵乘法都被替换为手动降低的版本,其中计算(融合)与通信(ppmermute)完全重叠!一个有趣的提示表明我们正在使用延迟优化变体是 ppmerute 像素是抖动的 - 因为有两个重叠的 ppermutes 同时使用相反的 ICI 轴!
全对全更难重叠,因此被搁置一旁。
为什么 pmap
或 xmap
还没有解决这个问题?#
pmap
是我们第一个多设备并行 API。它遵循每个设备代码和显式集合的模式。但它有重大缺陷,使其不适合当今的程序
映射多个轴需要嵌套的
pmap
。 嵌套的pmap
不仅编写起来很麻烦,而且它们还难以控制(甚至预测)数据和计算的设备放置,并且难以保留数据分片(见下两个要点)。今天的程序需要多轴并行。控制设备放置是不可能的。 尤其是在多轴并行的情况下,程序员需要控制这些轴如何与硬件资源及其通信拓扑对齐。但是(嵌套的)
pmap
不提供对映射的程序实例如何在硬件上放置的控制;只有一个自动设备顺序,用户无法控制。(Gopher 使用axis_index_groups
和单个非嵌套pmap
本质上是解决此问题的 hack,通过将多个并行轴扁平化到一个轴。)jit
/pjit
可组合性。jit
-of-pmap
是一个性能陷阱,就像嵌套pmap
那样,就像例如scan
-of-pmap
那样,因为从内部pmap
返回时不会保留分片。为了保留分片,我们需要对 jaxprs 进行模式匹配,以确保我们正在使用完美嵌套的 pmap,或者一个 pmap 就在jit
内部。此外,pjit
在这里没有帮助,因为pmap
针对 XLA 复制品,而pjit
针对 XLA SPMD 分区器,这两个的组合很困难。jax.Array
兼容性(以及pjit
兼容性)。 由于pmap
输出的分片无法表示为Shardings
/OpShardings
,这是由于pmap
的堆叠而不是连接语义,因此pmap
计算的输出目前无法传递给pjit
计算,而无需反弹到主机(或调度重新整形计算)。多控制器语义(以及
pjit
兼容性)。 多控制器pmap
在控制器之间连接值,这工作得很好,但与单控制器pmap
的堆叠语义不同。更实际地说,它排除了使用非完全可寻址的jax.Array
输入和输出,正如我们在多控制器pjit
中所使用的那样。急切模式。 我们没有使
pmap
成为急切优先的,尽管我们最终(在 4 多年后!)添加了急切操作,使用disable_jit()
,但pmap
具有jit
融合到其中的事实意味着它有自己的编译和调度路径(实际上有两个调度路径:在 Python 中用于处理Tracer
,以及在 C++ 中用于在原始Array
输入上的性能!),这是沉重的实现负担。调用者中需要重新整形。 使用
pmap
在 8 个设备上的典型用例可能看起来像从大小为 128 的批处理轴开始,将其重新整形为分成两个大小为 (8, 16) 的轴,然后在第一个轴上进行pmap
。这些重新整形很笨拙,编译器通常将它们解释为副本而不是视图 - 增加了内存和时间使用量。
当只做批处理数据并行时,这些缺点并不那么糟糕。但当涉及到更多的并行性时,pmap
就不行了!
xmap
为 pmap
的下一代进化铺平了道路,并解决了(几乎)所有这些问题。 shmap
遵循 xmap
的脚步,并以本质上相同的方式解决了这些问题;实际上,shmap
就像 xmap
的一个专门子集(有些人称之为“硬 xmap
”子集),做了一些调整。
对于最初的原型,我们选择将 shmap
作为与 xmap
不同的基元实现,因为限制它支持的功能集使其更容易专注于核心功能。例如,shmap
不允许未映射的中间值,这使得更容易不必担心命名轴和自动微分之间的交互。此外,不必推理所有对功能的交互,这使得更容易添加超出 xmap
今天实现的功能,例如对急切模式的支持。
shmap
和 xmap
共享大量降低代码。我们将来可以考虑合并两者,甚至只专注于 shmap
,具体取决于使用方式如何演变。