使用 shard_map
的手动并行#
概述#
shard_map
是一个单程序多数据 (SPMD) 多设备并行 API,用于将函数映射到数据分片上。映射的函数应用,或实例,通过显式的集体通信操作相互通信。
shard_map
与内置于 jit
中的自动编译器并行化是互补的,并且可以组合使用。使用 jit
时,您编写的代码就像是针对单个设备一样,并且编译器可以自动将计算划分为多个设备,在后台生成每个设备的代码和通信集合。使用 shard_map
时,您可以掌控一切,编写自己的分区代码和显式集合。或者,您可以两者都做:在设备组之间进行手动控制,同时将组内的设备分区留给编译器。这两种方法可以根据需要混合、匹配和组合。
如果您熟悉 pmap
,可以将 shard_map
视为一种演进。它更具表现力、性能更好,并且可以与其他 JAX API 组合使用。它甚至可以立即运行,以便于调试!(有关详细信息,请参阅与 pmap
的详细比较。)
通过阅读本教程,您将学习如何使用 shard_map
来完全控制您的多设备代码。您将详细了解它如何与 jax.jit
的自动并行化和 jax.grad
的自动微分组合使用。我们还将提供一些神经网络并行化策略的基本示例。
我们假设本教程在具有八个设备的环境中运行
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices
那么,让我们来看一个 shard_map
!#
事不宜迟,这是一个玩具示例
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh = jax.make_mesh((4, 2), ('x', 'y'))
a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 4.).reshape(16, 4)
@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
out_specs=P('x', None))
def matmul_basic(a_block, b_block):
# a_block: f32[2, 8]
# b_block: f32[8, 4]
c_partialsum = jnp.dot(a_block, b_block)
c_block = jax.lax.psum(c_partialsum, 'y')
# c_block: f32[2, 4]
return c_block
c = matmul_basic(a, b) # c: f32[8, 4]
此函数通过执行局部块矩阵乘法,然后进行集体求和运算,并行计算矩阵乘法。我们可以检查结果是否正确
from jax.tree_util import tree_map, tree_all
def allclose(a, b):
return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))
allclose(c, jnp.dot(a, b))
True
结果沿着其行分片
jax.debug.visualize_array_sharding(c)
CPU 0,1 CPU 2,3 CPU 4,5 CPU 6,7
从高层次来看,shard_map
类似于 vmap
或 pmap
,因为我们正在将函数映射到数组数据的各个部分上,但是请注意
shard_map
将输入切分成块(并且输出是通过连接结果块形成的),保持秩不变,而vmap
会通过映射掉一个轴来降低秩;mesh
参数允许我们控制计算和结果的精确设备放置;我们正在同时映射多个数据轴,并为集合设置多个轴名称(此处为
'x'
和'y'
);由于我们还没有使用
jax.jit
,因此所有内容都是立即计算的,我们甚至可以print
中间值以进行调试。
以上代码执行的计算与此 jax.jit
自动并行化代码相同
from jax.sharding import NamedSharding
a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))
b = jax.device_put(b, NamedSharding(mesh, P('y', None)))
@jax.jit
def matmul_reference(a, b):
c = jnp.dot(a, b)
return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))
c_ref = matmul_reference(a, b)
allclose(c_ref, jnp.dot(a, b))
True
我们可以将 shard_map
视为根据其 mesh
和 in_specs
参数对其输入执行 device_put
或 with_sharding_constraint
,因此 matmul_basic
操作的块与 matmul_reference
中的块相同
print('a blocks:'); jax.debug.visualize_array_sharding(a)
print('b blocks:'); jax.debug.visualize_array_sharding(b)
print('c blocks:'); jax.debug.visualize_array_sharding(c)
a blocks:
b blocks:
c blocks:
CPU 0 CPU 1 CPU 2 CPU 3 CPU 4 CPU 5 CPU 6 CPU 7
CPU 0,2,4,6 CPU 1,3,5,7
CPU 0,1 CPU 2,3 CPU 4,5 CPU 6,7
慢下来,从基础开始!#
降秩映射与保持秩映射#
我们可以将 vmap
和 pmap
视为沿着轴解堆叠每个数组输入(例如,将二维矩阵解包成其一维行),将它的主体函数应用于每个片段,并将结果堆叠在一起,至少在不涉及集合时是这样
def check_vmap(f, xs):
ans = jax.vmap(f, in_axes=(0,), out_axes=0)(xs)
expected = jnp.stack([f(x) for x in xs]) # vmap reference semantics
print(allclose(ans, expected))
check_vmap(lambda x: x @ x, jnp.arange(12).reshape(4, 3))
True
例如,如果 xs
的形状为 f32[8,5]
,则每个 x
的形状将为 f32[5]
,并且如果每个 f(x)
的形状为 f32[3,7]
,则最终堆叠结果 vmap(f)(xs)
的形状将为 f32[8,3,7]
。也就是说,主体函数 f
的每次应用都会将比 vmap(f)
的相应参数少一个轴的输入作为参数。我们可以说这些是输入/输出解堆叠/堆叠的降秩映射。
f
的逻辑应用次数,或 f
的实例,由映射的输入轴的大小决定:例如,如果我们映射大小为 8 的输入轴,则语义上我们得到该函数的 8 个逻辑应用。
相反,shard_map
没有这种降秩行为。相反,我们可以将其视为沿着输入轴切片(或“解连接”)成块,应用主体函数,并将结果连接在一起(同样,在不涉及集合时)
import numpy as np
devices = np.array(jax.devices()[:4])
mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4
def check_shmap(f, y):
ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y)
expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])])
print(allclose(ans, expected))
check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4))
True
回想一下,jnp.split
将其输入切片成相同大小且秩相同的块,因此,如果在上面的示例中,y
的形状为 f32[8,5]
,则每个 y_blk
的形状将为 f32[2,5]
,并且如果每个 f(y_blk)
的形状为 f32[3,7]
,则最终连接的结果 shard_map(f, ...)(y)
的形状将为 f32[12,7]
。因此,shard_map
映射其输入的分片或块。我们可以说它是输入/输出解连接/连接的保持秩映射。
f
的逻辑应用次数由网格大小决定,而不是由任何输入轴大小决定:例如,如果我们有一个总大小为 4 的网格(即在 4 个设备上),则语义上我们得到该函数的 4 个逻辑应用,对应于实际计算它们的 4 个设备。
使用 in_specs
控制如何分割(解连接)和铺平每个输入#
每个 in_specs
都通过使用 PartitionSpec
按名称将对应输入数组的一些轴标识为网格轴,表示如何将该输入拆分(或解连接)为应用主体函数的块。该标识确定分片大小;当输入轴被标识为网格轴时,输入将沿着该逻辑轴拆分(解连接)为等于相应网格轴大小的多个片段。(如果对应的网格轴大小不能均匀地分割输入数组轴大小,则会出错。)如果输入的 pspec 未提及网格轴名称,则不会在该网格轴上进行拆分。例如
mesh = jax.make_mesh((4, 2), ('i', 'j'))
@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
print(x_block.shape) # prints (3, 12)
return x_block
x1 = jnp.arange(12 * 12).reshape(12, 12)
y = f1(x1)
(3, 12)
在此处,由于输入 pspec 未提及网格轴名称 'j'
,因此没有输入数组轴在该网格轴上拆分;类似地,由于输入数组的第二个轴未被标识为(因此未在该轴上拆分)任何网格轴,因此 f1
的应用程序可以沿该轴完全查看输入。
当输入 pspec 中未提及网格轴时,我们始终可以重写为效率较低的程序,其中提及所有网格轴,但调用者执行 jnp.tile
,例如
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
print(x_block.shape)
return x_block
x = jnp.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, mesh.shape['j'])) # x_ has shape (12, 24)
y = f2(x_) # prints (3,12), and f1(x) == f2(x_)
(3, 12)
换句话说,由于每个输入 pspec 可以零次或一次提及每个网格轴名称,而不是必须精确地提及每个名称一次,因此我们可以说,除了其输入中内置的 jnp.split
之外,shard_map
的输入中还内置了一个 jnp.tile
,至少在逻辑上是这样(尽管平铺可能不需要物理执行,具体取决于参数的物理分片布局)。要使用的平铺不是唯一的;我们也可以沿着第一个轴平铺,并使用 pspec P(('j', 'i'), None)
。
输入上可能发生物理数据移动,因为每个设备都需要拥有适当的数据副本。
使用 out_specs
控制如何通过连接、块转置和解平铺来组装每个输出#
与输入端类似,每个 out_specs
通过名称将相应的输出数组的某些轴与网格轴相关联,表示如何将输出块(每个函数体应用一个,或者等效地,每个物理设备一个)组装在一起,形成最终的输出值。例如,在上面的 f1
和 f2
示例中,out_specs
指示我们应该将沿两个轴的块结果连接在一起,从而在两种情况下都得到形状为 (12, 24)
的数组 y
。(如果函数体的输出形状(即输出块形状)的秩太小,无法满足相应输出 pspec 描述的连接,则会出错。)
当输出 pspec 中未提及网格轴名称时,它表示取消分片:当用户编写一个未提及其中一个网格轴名称的输出 pspec 时,他们承诺输出块沿该网格轴相等,因此在输出中仅使用沿该轴的一个块(而不是将所有块沿该网格轴连接在一起)。例如,使用与上面相同的网格
x = jnp.array([[3.]])
z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()
print(z) # prints the same as jnp.tile(x, (4, 2))
z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()
print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))
z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()
print(z) # prints the same as jnp.tile(x, (1, 1)), or just x
[[3. 3.]
[3. 3.]
[3. 3.]
[3. 3.]]
[[3.]
[3.]
[3.]
[3.]]
[[3.]]
闭包一个数组值的函数体等效于将该数组值作为具有相应的 P(None, None) 输入 pspec 的参数传递。作为另一个例子,更接近上面的其他例子
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
return jax.lax.psum(x_block, 'j')
x = jnp.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape)
(12, 6)
结果的第二个轴大小为 6,是输入第二个轴大小的一半。在这种情况下,由于集体 psum
,在输出 pspec 中不提及网格轴名称 'j'
表示的取消分片是安全的,这确保每个输出块沿相应的网格轴相等。以下是另外两个例子,我们改变了输出 pspec 中提及的网格轴
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
return jax.lax.psum(x_block, 'i')
x = jnp.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape) # (3,12)
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
return jax.lax.psum(x_block, ('i', 'j'))
y5 = f5(x)
print(y5.shape) # (3,6)
(3, 12)
(3, 6)
在物理方面,在输出 pspec 中不提及网格轴名称会从具有沿该网格轴复制布局的输出设备缓冲区组装一个 Array
。
没有运行时检查输出块是否真的沿要取消分片的网格轴相等,或者等效地,相应的物理缓冲区是否具有相等的值,因此可以解释为单个逻辑数组的复制布局。但是我们可以提供一个静态检查机制,该机制会在所有可能不正确的程序上引发错误。
由于 out_specs
可以零次或一次提及网格轴名称,并且因为它们可以以任何顺序提及,我们可以说,除了内置于其输出的 jnp.concatenate
之外,shard_map
还在其输出中内置了取消分片和块转置。
无论输出 pspec 如何,都不可能在输出上进行物理数据移动。相反,out_specs
仅编码如何将块输出组装为 Array
,或者物理上如何将设备上的缓冲区解释为单个逻辑 Array
的物理布局。
API 规范#
from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]
def shard_map(
f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
auto: collections.abc.Set[AxisName] = frozenset([]),
check_rep: bool = True,
) -> Callable:
...
其中
诸如
f
主体中的psum
之类的通信集合可以提及mesh
的轴名称;mesh
编码排列在数组中的设备以及相关的轴名称,就像它为sharding.NamedSharding
所做的那样;in_specs
和out_specs
是PartitionSpec
,可以仿射提及mesh
中的轴名称,以分别表示输入和输出的分片/取消连接和连接,未提及的名称分别对应于复制和取消分片(断言复制-所以给我一个副本);auto
是一个可选的轴名称集合,对应于在函数体中自动处理的mesh
名称的子集,就像在调用者中一样,而不是手动处理;check_rep
是一个可选的布尔值,指示是否静态检查out_specs
中的任何复制错误,以及是否启用相关的自动微分优化(请参阅 JEP)。
传递给 f
的参数的形状与传递给 shard_map
-of-f
的参数的形状具有相同的秩,并且传递给 f
的参数的形状是根据传递给 shard_map
-of-f
的相应参数的形状 shape
和相应的 PartitionSpec
spec
计算出来的,大致为 tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))
。
集合教程#
shard_map
不必是纯映射:函数应用程序可以通过集合进行相互通信,使用在 mesh
参数中定义的轴名称。
回想一下,shard_map
将一个函数映射到输入数据的分片或块上,因此
mesh = Mesh(jax.devices(), ('i',))
x = jnp.arange(16.)
f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))
y = f_shmapped(x)
计算相同的值,将 f
的应用程序评估为相同的参数值,就像这个参考函数一样
def f_shmapped_ref(x):
x_blocks = jnp.array_split(x, mesh.shape['i'])
y_blocks = [f(x_blk) for x_blk in x_blocks]
return jnp.concatenate(y_blocks)
我们将 f
应用于不同的参数分片称为函数实例。每个函数实例在不同的设备(或设备子集)上执行。
当 f
中没有通信集合时,这些参考语义起作用。但是,如果我们希望函数实例进行通信,这对应于进行跨设备通信,该怎么办?也就是说,当 f
包含集合时,参考语义是什么?假设 f
只有一个集合,并且形式为
def f(x_blk):
z_blk = f_part1(x_blk)
u_blk = collective(z_blk, axis_name)
v_blk = f_part2(x_blk, z_blk, u_blk)
return v_blk
其中我们假设只映射一个网格轴,并且 axis_name
是其对应的名称。那么参考语义将更像
def f_shmapped_ref(x):
x_blocks = jnp.array_split(x, mesh.shape[0])
z_blocks = [f_part1(x_blk) for x_blk in x_blocks]
u_blocks = [collective_ref(i, z_blocks) for i in range(len(z_blocks))]
v_blocks = [f_part2(x_blk, z_blk, u_blk) for x_blk, z_blk, u_blk
in zip(x_blocks, z_blocks, u_blocks)]
return jnp.concatenate(v_blocks)
请注意,collective_ref
可能取决于所有 z_blocks
。也就是说,虽然 f_part1
和 f_part2
是独立映射到块上的,但集合会引入一定量的跨块依赖关系。在物理上,这意味着跨设备的通信。究竟发生什么通信以及计算什么值取决于集合。
psum
#
最简单的集合可能是 jax.lax.psum
,它计算沿设备网格轴(或多个轴)的 all-reduce-sum。这是一个玩具示例
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh1d = Mesh(jax.devices()[:4], ('i',))
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))
def f1(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum(x_block, 'i')
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f1(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22 20 12 17]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[22 20 12 17]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[22 20 12 17]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[22 20 12 17]
FINAL RESULT:
[22 20 12 17]
打印结果显示,每个函数应用程序都以其自己的参数值 x_block
的块开始。在 psum
之后,每个函数应用程序都具有相同的 y_block
值,该值是通过将应用程序的 x_block
值加在一起计算得出的。
在计算中只有一个轴名称的情况下,我们可以说 psum
的 collective_ref
参考实现为
def psum_ref(_, x_blocks):
tot = sum(x_blocks)
return [tot] * len(x_blocks)
另请注意,由于 f1
返回 y_block
,即在 'i'
上进行 psum
的结果,我们可以使用 out_specs=P()
,因此调用者会获得结果值的单个逻辑副本,而不是分片的结果。
当有多个网格轴时,我们可以分别对每个网格轴执行 psum
,也可以一次对多个轴执行
mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))
@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f2(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum(x_block, 'i')
print('AFTER:\n', y_block)
return y_block
y = f2(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
[4 5]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
[6 7]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 9]
[12 13]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
[14 15]]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[ 8 10]
[16 18]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[12 14]
[20 22]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 10]
[16 18]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[12 14]
[20 22]]
FINAL RESULT:
[[ 8 10 12 14]
[16 18 20 22]]
通过在网格轴 'i'
上应用 psum
,我们得到沿轴 ‘i'
相等,但沿轴 'j'
不相等的 y_block
值。(因此,我们可以使用 out_specs=P(None, 'j')
来获得沿该轴的单个逻辑结果。)
如果我们对两个轴都应用 psum
,则 y_block
值在两个轴上都相等
@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))
def f3(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum(x_block, ('i', 'j'))
print('AFTER:\n', y_block)
return y_block
y = f3(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
[4 5]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
[6 7]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 9]
[12 13]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
[14 15]]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[20 24]
[36 40]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[20 24]
[36 40]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[20 24]
[36 40]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[20 24]
[36 40]]
FINAL RESULT:
[[20 24]
[36 40]]
在机器学习中,我们经常使用 psum
来计算总损失,或者当我们在 shard_map
映射的函数体内部有 grad
时,计算总梯度。
接下来,我们将看到如何用其他原语实现 psum
,这将使我们对它的通信成本有一些直观的了解。
all_gather
#
另一个基本操作是沿轴收集数组分片,以便每个函数应用都具有沿该轴的完整数据副本。
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f4(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.all_gather(x_block, 'i', tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 9, 5, 2])
y = f4(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 9 5 2]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[3 9 5 2]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[3 9 5 2]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[3 9 5 2]
FINAL RESULT:
[3 9 5 2 3 9 5 2 3 9 5 2 3 9 5 2]
打印结果显示,每个函数应用都再次从其自身的参数值 x_block
开始。在 all_gather
之后,它们有一个共同的值,该值是通过连接 x_block
的值计算得出的。
(请注意,实际上我们不能在此处设置 out_specs=P()
。由于与自动微分相关的技术原因,我们认为 all_gather
的输出不能保证在设备之间保持不变。如果我们想保证它保持不变,我们可以使用 jax.lax.all_gather_invariant
,或者在这种情况下,我们可以避免在函数体中执行 all_gather
,而是直接使用 out_specs=P('i')
来执行连接。)
当 tiled=False
(默认值)时,结果会沿新轴堆叠,而不是连接。
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f5(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.all_gather(x_block, 'i', tiled=False)
print('AFTER:\n', y_block)
return y_block
y = f5(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[[3]
[9]
[5]
[2]]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[[3]
[9]
[5]
[2]]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[[3]
[9]
[5]
[2]]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[[3]
[9]
[5]
[2]]
FINAL RESULT:
[[3]
[9]
[5]
[2]
[3]
[9]
[5]
[2]
[3]
[9]
[5]
[2]
[3]
[9]
[5]
[2]]
我们可以为 all_gather
编写 collective_ref
参考语义函数,如下所示:
def all_gather_ref(_, x_blocks, *, tiled=False):
combine = jnp.concatenate if tiled else jnp.stack
return [combine(x_blocks)] * len(x_blocks)
在深度学习中,我们可能会在完全分片的数据并行 (FSDP) 中对参数使用 all_gather
。
psum_scatter
#
jax.lax.psum_scatter
集合操作有点不太直观。它类似于 psum
,只是每个函数实例只得到结果的一个分片。
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f6(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f6(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]
FINAL RESULT:
[22 20 12 17]
如打印结果所示,与 psum
不同,每个结果 y_block
的大小都比参数 x_block
小。此外,与 psum
相比,这里的每个 y_block
仅表示跨函数实例的 x_block
总和的一个切片。(尽管每个函数实例只获得总和的一个分片,但最终输出 y
与 psum
示例中的相同,因为这里我们使用 out_specs=P('i')
来连接每个函数实例的输出。)
就计算的值而言,collective_ref
参考实现可能如下所示:
def psum_scatter_ref(i, x_blocks, *, tiled=False):
axis_size = len(x_blocks)
tot = sum(x_blocks)
if tiled:
tot = tot.reshape(axis_size, -1, *tot.shape[1:]) # split leading axis
return [tot[i] for i in range(tot.shape[0])]
在语义参考实现中没有捕获,但 psum_scatter
非常有用,因为这些结果可以更有效率地计算,通信量更少,而不是完整的 psum
。实际上,可以将 psum_scatter
视为“在 all_gather
之前的 psum
的前半部分”。也就是说,实现 psum
的一种方法是:
def psum(x, axis_name):
summed_chunk = jax.lax.psum_scatter(x, axis_name)
return jax.lax.all_gather(summed_chunk, axis_name)
实际上,这种实现经常在 TPU 和 GPU 上使用!
psum_scatter
可能需要大约一半的通信量才能达到完整的 psum
的原因是 ppermute
部分中说明的。
另一个直观的理解是,我们可以使用 psum_scatter
来实现分布式矩阵乘法,其输入和输出沿同一轴分片。在机器学习中,psum_scatter
可用于张量并行矩阵乘法或完全分片的数据并行梯度累积,如下面的示例所示。
ppermute
#
jax.lax.ppermute
集合操作为函数实例提供了一种最直接的相互发送数据的方式。给定一个网格轴和表示沿该网格轴的索引的 (source_index, destination_index)
对列表,ppermute
将其参数值从每个源函数实例发送到每个目标。
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f7(x_block):
sz = jax.lax.psum(1, 'i')
print('BEFORE:\n', x_block)
y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)])
print('AFTER:\n', y_block)
return y_block
y = f7(jnp.arange(8))
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[0 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[2 3]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[6 7]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[6 7]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[0 1]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[2 3]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[4 5]
FINAL RESULT:
[6 7 0 1 2 3 4 5]
在这种情况下,只有两个函数实例,每个实例的 y_block
值是另一个实例的 x_block
值。
源索引和目标索引不能重复。如果一个索引没有作为目标出现,则相应函数实例结果的值是一个零数组。
collective_ref
参考实现可能如下所示:
def ppermute_ref(i, x_blocks, perm):
results = [jnp.zeros_like(x_blocks[0])] * len(x_blocks)
for src, dst in perm:
results[dst] = x_blocks[src]
return results
可以使用 ppermute
有效地实现其他集合操作,在总通信方面,其中每个函数只将数据传递给它的邻居。例如,我们可以使用一系列 ppermute
和局部加法来实现 psum_scatter
,如下所示:
或者,使用数值示例:
直观地说,在每次迭代中,每个函数实例都会将它在前一次迭代中收到的值“向上”发送,并减少(添加)它在这次迭代中收到的值。在代码中,它可能看起来像这样:
def psum_scatter(x, axis_name, *, tiled=False):
size = jax.lax.psum(1, axis_name)
idx = jax.lax.axis_index(axis_name) # function instance index along axis_name
if tiled:
x = x.reshape(size, -1, *x.shape[1:]) # split leading axis
shift = partial(jax.lax.ppermute, axis_name=axis_name,
perm=[(i, (i - 1) % size) for i in range(size)])
for i in range(1, size):
update = shift(x[(idx + i) % size])
x = x.at[(idx + i + 1) % size].add(update)
return x[idx]
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f8(x_block):
print('BEFORE:\n', x_block)
y_block = psum_scatter(x_block, 'i', tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f8(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]
FINAL RESULT:
[22 20 12 17]
在 TPU 上,有这种算法的高维变体,可以利用多个双向物理网格轴。
请注意,psum_scatter
是 all_gather
的转置。实际上,用 ppermute
实现 all_gather
的一种方法看起来与上述过程相反:
在深度学习中,我们可能会在实现 SPMD 流水线并行时使用 ppermute
,其中我们将网络沿深度划分为多个阶段,并并行评估阶段的应用。或者,我们可能会在并行化卷积层的评估时使用 ppermute
,其中我们沿空间轴分片,因此设备必须相互通信“晕轮”。或者,它可能在张量并行矩阵乘法中在底层使用。
all_to_all
#
最后一个集合操作是 all_to_all
,它本质上是沿一个位置轴和一个跨设备轴操作的块矩阵转置。
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f9(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0,
tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f9(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 5 5 9]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[1 9 3 7]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 2 5 1]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[1 6 8 2]
FINAL RESULT:
[3 5 5 9 1 9 3 7 4 2 5 1 1 6 8 2]
split_axis
参数指示应在网格轴上分片和分区的哪个位置轴。concat_axis
参数指示应将通信结果连接或堆叠的轴。
当 tiled=False
(默认值)时,split_axis
轴大小必须等于名为 axis_name
的网格轴的大小,并且在该大小的新的轴在位置 concat_axis
处创建,用于堆叠结果。当 tiled=True
时,split_axis
轴大小只需要能被网格轴的大小均匀整除,并且结果将沿现有轴 concat_axis
连接。
当 split_axis=0
和 concat_axis=0
时,collective_ref
参考语义可能如下所示:
def all_to_all_ref(_, x_blocks, *, tiled=False):
axis_size = len(x_blocks)
if tiled:
splits = [jnp.array_split(x, axis_size) for x in x_blocks]
return [jnp.concatenate(s) for s in zip(*splits)]
else:
splits = [list(x) for x in x_blocks]
return [jnp.stack(s) for s in zip(*splits)]
在深度学习中,我们可能会在混合专家路由中使用 all_to_all
,其中我们首先根据示例应该转到哪个专家对本地批处理示例进行排序,然后应用 all_to_all
将示例重新分发给专家。
玩具示例#
我们如何在实践中使用 shard_map
和集体通信?这些例子虽然简单,但也给出了一些思路。
矩阵乘法#
并行化矩阵乘法是扩展深度学习模型的关键,无论是在训练还是在推理中。当 jax.jit
自动并行化矩阵乘法时,它可以根据矩阵大小、硬件细节和其他因素使用多种不同的策略。我们如何使用 shard_map
更明确地编写其中一些并行化例程?以及我们如何优化它们以获得更好的计算/通信重叠,从而提高 FLOP 利用率?
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh = Mesh(jax.devices()[:4], ('i',))
def device_put(x, pspec):
return jax.device_put(x, NamedSharding(mesh, pspec))
示例 1:单侧 all-gather
#
考虑执行矩阵乘法,其中我们将左侧参数(可以认为是:参数)在其前导(非收缩)维度上进行分片
lhs_spec = P('i', None)
lhs = device_put(jax.random.normal(jax.random.key(0), (8, 8)), lhs_spec)
我们将右侧参数(可以认为是:激活)在其收缩维度上进行分片,输出也进行类似的分片
rhs_spec = P('i', None)
rhs = device_put(jax.random.normal(jax.random.key(1), (8, 4)), rhs_spec)
为了执行此矩阵乘法,我们可以先对右侧进行 all-gather,然后对分片的左侧执行局部矩阵乘法
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_allgather(lhs_block, rhs_block):
rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)
return lhs_block @ rhs
out = matmul_allgather(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
这很好,但我们在这里没有获得任何计算/通信重叠:在我们开始矩阵乘法之前,我们需要 all_gather
完成。这是使用相同代码但在更大的示例形状上(lhs
为 (8192, 8192)
,rhs
为 (8192, 1024)
)的配置文件
如果我们不调用 all_gather
,而是基本上内联我们上面使用 ppermute
实现的 all_gather
,然后将收集置换的步骤与局部矩阵乘法交错,我们就可以获得计算/通信重叠
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_allgather_overlapped(lhs_block, rhs_block):
size = jax.lax.psum(1, 'i')
idx = jax.lax.axis_index('i')
shift = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i + 1) % size) for i in range(size)])
B = lhs_block.shape[1] // size
lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)
out_block = lhs_blocks(idx) @ rhs_block
for i in range(1, size):
rhs_block = shift(rhs_block)
out_block += lhs_blocks((idx - i) % size) @ rhs_block
return out_block
out = matmul_allgather_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
此实现允许通信和计算之间的重叠,并且还避免在每个设备上收集一个大的中间结果。但在 TPU 上,它通过仅沿环的一个方向置换来使用一半的互连带宽。为了进行双向置换,我们只需将块分成两半,然后将每一半发送到每个方向
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):
size = jax.lax.psum(1, 'i')
idx = jax.lax.axis_index('i')
shift_up = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i + 1) % size) for i in range(size)])
shift_dn = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i - 1) % size) for i in range(size)])
B = lhs_block.shape[1] // size // 2 # half-size blocks
lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 1)
rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0)
out_block = lhs_blocks(idx, 0) @ rhs_block_lo
out_block += lhs_blocks(idx, 1) @ rhs_block_hi
for i in range(1, size):
rhs_block_lo = shift_up(rhs_block_lo)
rhs_block_hi = shift_dn(rhs_block_hi)
out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo
out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi
return out_block
out = matmul_allgather_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
在实践中,为了减少编译时间,我们可能会将其融入 jax.lax.fori_loop
。我们也可能涉及额外的并行轴。
示例 2:psum_scatter
结果#
我们可能开始的另一种分片方式是将 lhs
和 rhs
都沿着它们的收缩维度进行分片,输出再次像 rhs
一样进行分片
lhs_spec = P(None, 'i')
lhs = device_put(lhs, lhs_spec)
rhs_spec = P('i', None)
rhs = device_put(rhs, rhs_spec)
在这里,我们可以使用 reduce_scatter
来执行跨分片的收缩求和
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_psumscatter(lhs_block, rhs_block):
out_summand = lhs_block @ rhs_block
return jax.lax.psum_scatter(out_summand, 'i', tiled=True)
out = matmul_psumscatter(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
但是,散射通信必须等待整个局部矩阵乘法完成才能开始。为了获得通信/计算重叠,我们可以内联使用 ppermute
实现的 psum_scatter
,然后将通信步骤与局部矩阵乘法交错
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_psumscatter_overlapped(lhs_block, rhs_block):
size = jax.lax.psum(1, 'i')
idx = jax.lax.axis_index('i')
shift = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i - 1) % size) for i in range(size)])
lhs_block = lhs_block.reshape(size, -1, lhs_block.shape[1]) # split 1st axis
out_summand = lhs_block[(idx + 1) % size] @ rhs_block
for i in range(1, size):
out_summand = shift(out_summand)
out_summand += lhs_block[(idx + i + 1) % size] @ rhs_block
return out_summand
out = matmul_psumscatter_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
与前面的示例一样,为了充分利用 TPU 上的互连,我们将运行一个双向版本
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):
size = jax.lax.psum(1, 'i')
idx = jax.lax.axis_index('i')
shift_up = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i + 1) % size) for i in range(size)])
shift_dn = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i - 1) % size) for i in range(size)])
B = lhs_block.shape[0] // size // 2 # half-size blocks
lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 0)
out_summand_lo = lhs_blocks((idx - 1) % size, 0) @ rhs_block
out_summand_hi = lhs_blocks((idx + 1) % size, 1) @ rhs_block
for i in range(1, size):
out_summand_lo = shift_up(out_summand_lo)
out_summand_hi = shift_dn(out_summand_hi)
out_summand_lo += lhs_blocks((idx - i - 1) % size, 0) @ rhs_block
out_summand_hi += lhs_blocks((idx + i + 1) % size, 1) @ rhs_block
return jnp.concatenate([out_summand_lo, out_summand_hi])
out = matmul_psumscatter_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
神经网络#
我们可以使用 shard_map
来并行化神经网络中的计算,既可以单独使用,也可以与 jax.jit
中的自动分区结合使用。本节包含一些基于此玩具神经网络和随机数据的示例
import jax
import jax.numpy as jnp
def predict(params, inputs):
for W, b in params:
outputs = jnp.dot(inputs, W) + b
inputs = jax.nn.relu(outputs)
return outputs
def loss(params, batch):
inputs, targets = batch
predictions = predict(params, inputs)
return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
def init_layer(key, n_in, n_out):
k1, k2 = jax.random.split(key)
W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
b = jax.random.normal(k2, (n_out,))
return W, b
def init(key, layer_sizes, batch_size):
key, *keys = jax.random.split(key, len(layer_sizes))
params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
key, *keys = jax.random.split(key, 3)
inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))
return params, (inputs, targets)
layer_sizes = [784, 128, 128, 128, 128, 128, 8]
batch_size = 32
params, batch = init(jax.random.key(0), layer_sizes, batch_size)
将这些示例与“分布式数组和自动分区”文档中的纯自动分区示例进行比较。虽然在这些自动分区示例中,我们无需编辑模型函数即可使用不同的并行化策略,但在使用 shard_map
时,我们通常需要这样做。
8 路批处理数据并行#
最简单的多设备并行策略是将输入和目标批次分片到多个设备上,在这些设备上复制参数,并并行地将模型应用于这些数据分片。为了评估总损失,设备只需要在最后进行标量大小的 all-reduce-sum 通信。(为了评估损失的梯度,设备必须在反向传播中执行参数梯度的 all-reduce-sum。)
from functools import partial
from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh = jax.make_mesh((8,), ('batch',))
# replicate initial params on all devices, shard data batch over devices
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P()))
# adapt the loss function to sum the losses across devices
def loss_dp(params, batch):
@partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())
def loss_spmd(local_batch):
inputs, targets = local_batch
predictions = predict(params, inputs) # use reference 'predict`
local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
return jax.lax.pmean(local_loss, 'batch')
return loss_spmd(batch)
我们可以检查损失及其梯度是否与参考(基本)模型匹配
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_dp)(params, batch))
11.920298
11.920298
def allclose(a, b):
return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
jax.jit(jax.grad(loss_dp))(params, batch)))
True
我们可以打印编译器 IR 来检查梯度计算,并验证集体 all-reduce-sum 操作是否发生在我们期望的位置:在前向传播结束时计算损失值,并在反向传播中计算总参数梯度。
8 路完全分片数据并行 (FSDP)#
另一种策略是在设备上额外分片参数,在需要完整值进行 jnp.dot
或偏差相加时收集每个参数。由于我们一次只有一个完整的参数在本地设备内存中,而不是像前面的 DP 示例中那样将所有参数保存在所有设备内存中,因此我们释放了大量内存,可用于更大的模型或更大的批次大小。而且由于 XLA 将重叠计算和设备间通信,因此挂钟时间不会受到影响。
因此,现在我们需要在两个地方进行集合:模型预测函数 predict
需要在参数使用之前收集它们,并且与 DP 情况一样,损失函数需要将局部损失相加以计算总损失。
我们需要另一个要素:我们不想存储前向传播中完全收集的参数以在反向传播中使用。相反,我们希望在反向传播中再次收集它们。我们可以通过使用带有自定义策略(或 custom_vjp
)的 jax.remat
来表示这一点,尽管 XLA 通常会自动执行这种重物化。
这种通用的 FSDP 方法类似于权重更新分片 (WUS) 和 ZeRO-3。
# shard data batch *and params* over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P('batch')))
# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them)
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp(params_frag, inputs):
for W_frag, b_frag in params_frag:
W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
outputs = jnp.dot(inputs, W) + b
inputs = jax.nn.relu(outputs)
return outputs
def loss_fsdp(params, batch):
@partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())
def loss_spmd(local_params, local_batch):
inputs, targets = local_batch
predictions = predict_fsdp(local_params, inputs)
local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
return jax.lax.pmean(local_loss, 'batch')
return loss_spmd(params, batch)
我们可以再次检查损失及其梯度是否与参考模型匹配
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp)(params, batch))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
jax.jit(jax.grad(loss_fsdp))(params, batch)))
11.920299
11.920299
True
8 路张量并行 (TP)#
通常我们不单独使用张量模型并行,但单独查看它是并行矩阵乘法的一个很好的热身。这也是在更大的基于 jit
的计算中调用的库函数中使用 shard_map
的一个很好的例子。
并行化思想是,我们将数据/激活在其特征轴上(而不是其批处理轴上)进行分片,并且我们将类似地将权重矩阵在其输入特征轴上(以及偏差在其特征轴上)进行分片。然后,为了执行并行矩阵乘法,我们将执行局部矩阵乘法,然后执行 psum_scatter
以对局部结果求和并有效地分散结果的分片。
mesh = jax.make_mesh((8,), ('feats',))
batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))
params = jax.device_put(params, NamedSharding(mesh, P('feats')))
def predict_tp(params, inputs):
for W, b in params:
outputs = gemm_tp(inputs, W, b)
inputs = jax.nn.relu(outputs)
return outputs
@partial(shard_map, mesh=mesh,
in_specs=(P(None, 'feats'), P('feats', None), P('feats')),
out_specs=P(None, 'feats'))
def gemm_tp(inputs, W, b):
block_result = jnp.dot(inputs, W)
return jax.lax.psum_scatter(block_result, 'feats',
scatter_dimension=1, tiled=True) + b
def loss_tp(params, batch):
inputs, targets = batch
predictions = predict_tp(params, inputs)
return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1)) # NOTE psum!
FSDP + TP,在顶层使用 shard_map
#
我们可以将这些策略组合在一起,使用多个并行轴。
mesh = jax.make_mesh((4, 2), ('batch', 'feats'))
batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))
params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))
# mostly same as previous predict_fsdp definition, except we call gemm_tp
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp_tp(params_frag, inputs):
for W_frag, b_frag in params_frag:
W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
block_result = jnp.dot(inputs, W)
outputs = jax.lax.psum_scatter(block_result, 'feats',
scatter_dimension=1, tiled=True) + b
inputs = jax.nn.relu(outputs)
return outputs
@partial(shard_map, mesh=mesh,
in_specs=(P(('feats', 'batch')), P('batch', 'feats')),
out_specs=P())
def loss_fsdp_tp(local_params, local_batch):
inputs, targets = local_batch
predictions = predict_fsdp_tp(local_params, inputs)
sq_err = jax.lax.psum(jnp.sum((predictions - targets)**2, axis=-1), 'feats')
return jax.lax.pmean(jnp.mean(sq_err), 'batch')
请注意,我们必须进行两次集体缩减:一次在 'feats'
上,一次在 'batch'
上。在纯 TP 示例中,我们没有显式编写 'feats'
缩减,因为我们仅在 gemm_tp
中使用了 shard_map
;在调用方 loss_tp
中,编译器根据 predict_tp
返回的分片结果,自动将我们对 jnp.sum
的使用转换为按需执行的 psum
。
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp_tp)(params_, batch_))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
jax.jit(jax.grad(loss_fsdp_tp))(params, batch)))
11.920299
11.920298
True
SPMD 流水线并行 (PP)#
通过流水线并行,我们的目标是并行化网络中不同深度的层的评估。例如,一个设备可能计算第一层的应用,而另一个设备计算第二层的应用;当它们完成时,第一个设备将其结果传递给第二个设备,而第二个设备将其结果传递给负责第三层的设备,并且该过程重复进行。通常,流水线级的数量可能与层的数量不同,因为每个级可能负责多个层。
通过 SPMD 流水线,我们利用了网络中的大多数层都应用计算的事实,只是参数值不同。特别是,我们可以将除第一层和最后一层之外的所有参数堆叠在一起,然后使用 shard_map
来映射这些层参数的块,其中每个参数块对应于一个流水线级。然后,我们使用 jax.lax.ppermute
集体将数据向下移动并行流水线。
这种特定的流水线策略本质上是 GPipe 策略。有几种变体,以及完全不同的策略,哪种策略合适可能取决于阶段之间的网络速度和批次大小。但是对于本教程,我们将仅关注一种策略。
首先,我们选择一些流水线参数
L = len(params) - 2 # num layers, excluding first and last
N = batch_size # batch size
F = params[0][0].shape[1] # num features
# choose some pipeline parameters
S = 2 # number of stages
B = 8 # size of each microbatch
assert L % S == 0, "S (number of stages) must divide L (number of inner layers)"
# compute some useful quantities
M, ragged = divmod(N, B) # M is number of microbatches
assert not ragged, "B (size of each microbatch) must divide total batch size"
K, ragged = divmod(M, S) # K is microbatches per stage
assert not ragged, "S (number of stages) must divide number of microbatches"
print(f'{S} stages, {L // S} layer(s) per stage, {L} pipelined layers total')
print(f'{B} examples per microbatch, {M} microbatches total')
2 stages, 2 layer(s) per stage, 4 pipelined layers total
8 examples per microbatch, 4 microbatches total
mesh = Mesh(jax.devices()[:S], ('stages',))
def predict_pp(params, inputs):
(W_first, b_first), inner_params, (W_last, b_last) = params
inputs = jax.nn.relu(jnp.dot(inputs, W_first) + b_first)
inputs = spmd_pipeline(lambda Wb, x: jax.nn.relu(x @ Wb[0] + Wb[1]),
inner_params, inputs)
outputs = jnp.dot(inputs, W_last) + b_last
return outputs
@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),
out_specs=P())
def loss_pp(params, batch):
inputs, targets = batch
predictions = predict_pp(params, inputs.reshape(K, B, -1)).reshape(K * B, -1)
local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
return jax.lax.pmean(local_loss, 'stages')
def spmd_pipeline(fn, stage_params, inputs):
stage = jax.lax.axis_index('stages')
outputs = jnp.zeros_like(inputs) * jnp.nan
state = jnp.zeros((L // S, B, F)) * jnp.nan
for i in range(M+L-1):
state = state.at[0].set(jnp.where(stage == 0, inputs[i % K], state[0]))
state = jax.vmap(fn)(stage_params, state)
outputs = outputs.at[(i-L+1) % K].set(jnp.where(stage == S-1, state[-1], outputs[(i-L+1) % K]))
state, inputs, outputs = shift(i, state, inputs, outputs)
outputs = jax.lax.ppermute(outputs, 'stages', [(i, (i+1) % S) for i in range(S)])
return outputs
def shift(i, state, inputs, outputs):
sh = lambda x, d: jax.lax.ppermute(x, 'stages', [(i, (i+d) % S) for i in range(S)])
state = jnp.roll(state, +1, axis=0).at[0].set(sh(state[-1], +1))
if (i % K) == (-1 % K):
inputs = sh(inputs, +1)
if ((i-L+1) % K) == (-1 % K):
outputs = sh(outputs, +1)
return state, inputs, outputs
first_params, *inner_params, last_params = params
Ws, bs = zip(*inner_params)
params_stacked = jnp.stack(Ws), jnp.stack(bs)
first_params = jax.device_put(first_params, NamedSharding(mesh, P()))
params_stacked = jax.device_put(params_stacked, NamedSharding(mesh, P('stages')))
last_params = jax.device_put(last_params, NamedSharding(mesh, P()))
params_ = first_params, params_stacked, last_params
batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages')))
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_pp)(params_, batch_))
11.920299
11.9203
_ = jax.jit(jax.grad(loss_pp))(params_, batch_) # don't crash