使用 shard_map
进行手动并行#
概述#
shard_map
是一个单程序多数据 (SPMD) 多设备并行 API,用于将函数映射到数据分片。映射的函数应用程序或实例通过显式集体通信操作相互通信。
shard_map
与内置于 jit
中的自动编译器并行化是互补的,并且可以组合使用。使用 jit
,您可以像为单个设备编写代码一样,并且编译器可以自动将计算分配到多个设备上,在后台生成每个设备的代码和通信集合。使用 shard_map
,您可以控制编写自己的分区代码和显式集合。或者,您可以两者都做:在设备组之间进行手动控制,同时将组内设备分区留给编译器处理。这两种方法可以根据需要混合、匹配和组合。
如果您熟悉 pmap
,则可以将 shard_map
视为一种演变。它更具表现力、性能更高,并且可以与其他 JAX API 组合使用。它甚至可以急切地工作,以便于调试!(有关更多信息,请参阅与 pmap
的详细比较。)
通过阅读本教程,您将学习如何使用 shard_map
来完全控制您的多设备代码。您将详细了解它如何与 jax.jit
的自动并行化和 jax.grad
的自动微分组合使用。我们还将提供一些神经网络并行化策略的基本示例。
我们将假设本教程在具有八个设备的环境中运行
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices
因此,让我们看看 shard_map
!#
事不宜迟,这是一个玩具示例
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh = jax.make_mesh((4, 2), ('x', 'y'))
a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 4.).reshape(16, 4)
@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
out_specs=P('x', None))
def matmul_basic(a_block, b_block):
# a_block: f32[2, 8]
# b_block: f32[8, 4]
c_partialsum = jnp.dot(a_block, b_block)
c_block = jax.lax.psum(c_partialsum, 'y')
# c_block: f32[2, 4]
return c_block
c = matmul_basic(a, b) # c: f32[8, 4]
此函数通过执行局部块矩阵乘法,然后执行集体求和操作来并行计算矩阵乘法。我们可以检查结果是否正确
from jax.tree_util import tree_map, tree_all
def allclose(a, b):
return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))
allclose(c, jnp.dot(a, b))
True
结果沿其行进行分片
jax.debug.visualize_array_sharding(c)
CPU 0,1 CPU 2,3 CPU 4,5 CPU 6,7
从高层次上讲,shard_map
有点像 vmap
或 pmap
,因为我们正在将一个函数映射到数组数据的片段上,但请注意
shard_map
将输入切成块(并且输出是通过连接结果块形成的),保持秩不变,而vmap
会通过映射掉一个轴来降低秩;mesh
参数让我们控制计算和结果的精确设备放置;我们正在一次映射多个数据轴,并为集合设置多个轴名称(此处为
'x'
和'y'
);由于我们尚未使用
jax.jit
,因此所有内容都会被急切地评估,我们甚至可以print
中间值以进行调试。
上面的代码执行与此 jax.jit
自动并行化代码相同的计算
from jax.sharding import NamedSharding
a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))
b = jax.device_put(b, NamedSharding(mesh, P('y', None)))
@jax.jit
def matmul_reference(a, b):
c = jnp.dot(a, b)
return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))
c_ref = matmul_reference(a, b)
allclose(c_ref, jnp.dot(a, b))
True
我们可以将 shard_map
视为根据其 mesh
和 in_specs
参数对其输入执行 device_put
或 with_sharding_constraint
,因此 matmul_basic
操作的块与 matmul_reference
中的块相同
print('a blocks:'); jax.debug.visualize_array_sharding(a)
print('b blocks:'); jax.debug.visualize_array_sharding(b)
print('c blocks:'); jax.debug.visualize_array_sharding(c)
a blocks:
b blocks:
c blocks:
CPU 0 CPU 1 CPU 2 CPU 3 CPU 4 CPU 5 CPU 6 CPU 7
CPU 0,2,4,6 CPU 1,3,5,7
CPU 0,1 CPU 2,3 CPU 4,5 CPU 6,7
放慢速度,从基础知识开始!#
降秩映射与保秩映射#
我们可以将 vmap
和 pmap
视为沿轴拆堆每个数组输入(例如,将 2D 矩阵解压为其 1D 行),将其主体函数应用于每个片段,然后将结果堆叠在一起,至少在不涉及集合时是这样
def check_vmap(f, xs):
ans = jax.vmap(f, in_axes=(0,), out_axes=0)(xs)
expected = jnp.stack([f(x) for x in xs]) # vmap reference semantics
print(allclose(ans, expected))
check_vmap(lambda x: x @ x, jnp.arange(12).reshape(4, 3))
True
例如,如果 xs
的形状为 f32[8,5]
,则每个 x
的形状将为 f32[5]
,并且如果每个 f(x)
的形状为 f32[3,7]
,则最终堆叠结果 vmap(f)(xs)
的形状将为 f32[8,3,7]
。也就是说,主体函数 f
的每次应用都以比 vmap(f)
的相应参数少一个轴的输入作为参数。我们可以说这些是具有输入/输出拆堆/堆叠的降秩映射。
f
的逻辑应用次数或 f
的实例由正在映射的输入轴的大小决定:例如,如果我们映射大小为 8 的输入轴,则在语义上我们会得到 8 个函数的逻辑应用。
相比之下,shard_map
没有这种降秩行为。相反,我们可以将其视为沿输入轴切片(或“取消连接”)成块,应用主体函数,然后将结果连接在一起(同样,在不涉及集合时)
import numpy as np
devices = np.array(jax.devices()[:4])
mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4
def check_shmap(f, y):
ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y)
expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])])
print(allclose(ans, expected))
check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4))
True
回想一下,jnp.split
将其输入切片成大小相等的同秩块,因此如果在上面的示例中 y
的形状为 f32[8,5]
,则每个 y_blk
的形状将为 f32[2,5]
,并且如果每个 f(y_blk)
的形状为 f32[3,7]
,则最终连接结果 shard_map(f, ...)(y)
的形状将为 f32[12,7]
。因此,shard_map
映射其输入的分片或块。我们可以说它是一个具有输入/输出取消连接/连接的保秩映射。
f
的逻辑应用次数由网格大小决定,而不是由任何输入轴大小决定:例如,如果我们有一个总大小为 4 的网格(即在 4 个设备上),那么在语义上我们会得到 4 个函数的逻辑应用,对应于物理计算它们的 4 个设备。
使用 in_specs
控制如何拆分(取消连接)和平铺每个输入#
每个 in_specs
都使用 PartitionSpec
按名称标识相应的输入数组的某些轴与网格轴,表示如何将该输入拆分(或取消连接)成应用主体函数的块。该标识确定分片大小;当输入轴被标识为网格轴时,输入将沿该逻辑轴拆分(取消连接)成等于相应网格轴大小的多个片段。(如果相应的网格轴大小不能均匀地分割输入数组轴大小,则会出错。)如果输入的 pspec 没有提到网格轴名称,则不会在该网格轴上进行拆分。例如
mesh = jax.make_mesh((4, 2), ('i', 'j'))
@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
print(x_block.shape) # prints (3, 12)
return x_block
x1 = jnp.arange(12 * 12).reshape(12, 12)
y = f1(x1)
(3, 12)
在这里,由于输入 pspec 没有提到网格轴名称 'j'
,因此没有在网格轴上拆分任何输入数组轴;类似地,由于输入数组的第二个轴未被标识为(因此未在)任何网格轴上拆分,因此应用 f1
会获得沿该轴输入的完整视图。
当在输入 pspec 中未提及网格轴时,我们可以始终重写为效率较低的程序,其中提及所有网格轴,但调用者执行 jnp.tile
,例如
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
print(x_block.shape)
return x_block
x = jnp.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, mesh.shape['j'])) # x_ has shape (12, 24)
y = f2(x_) # prints (3,12), and f1(x) == f2(x_)
(3, 12)
换句话说,因为每个输入 pspec 可以提及每个网格轴名称零次或一次,而不是必须提及每个名称恰好一次,我们可以说,除了内置于其输入的 jnp.split
之外,shard_map
在其输入中还内置了 jnp.tile
,至少在逻辑上是这样(尽管根据参数的物理分片布局,可能不需要实际执行平铺)。要使用的平铺不是唯一的;我们也可以沿着第一个轴平铺,并使用 pspec P(('j', 'i'), None)
。
输入端可能会发生物理数据移动,因为每个设备都需要拥有适当数据的副本。
使用 out_specs
控制如何通过串联、块转置和反平铺来组装每个输出#
与输入端类似,每个 out_specs
通过名称将相应的输出数组的一些轴与网格轴关联起来,表示应该如何将输出块(主体函数的每次应用一个,或者等效地每个物理设备一个)重新组装在一起以形成最终输出值。例如,在上面的 f1
和 f2
示例中,out_specs
指示我们应该沿着两个轴将块结果连接在一起,从而形成最终输出,在两种情况下都得到一个形状为 (12, 24)
的数组 y
。(如果主体函数的输出形状(即输出块形状)的秩太小,无法满足相应输出 pspec 描述的连接,则会出错。)
当输出 pspec 中未提及网格轴名称时,它表示反平铺:当用户编写未提及网格轴名称之一的输出 pspec 时,他们承诺输出块沿该网格轴相等,因此在输出中仅使用沿该轴的一个块(而不是沿该网格轴将所有块连接在一起)。例如,使用与上面相同的网格
x = jnp.array([[3.]])
z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()
print(z) # prints the same as jnp.tile(x, (4, 2))
z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()
print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))
z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()
print(z) # prints the same as jnp.tile(x, (1, 1)), or just x
[[3. 3.]
[3. 3.]
[3. 3.]
[3. 3.]]
[[3.]
[3.]
[3.]
[3.]]
[[3.]]
封闭数组值的主体函数等效于将其作为参数传递,其对应的输入 pspec 为 P(None, None)。作为另一个示例,更接近上面的其他示例
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
return jax.lax.psum(x_block, 'j')
x = jnp.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape)
(12, 6)
结果的第二个轴大小为 6,是输入第二个轴大小的一半。在这种情况下,由于集体 psum
,在输出 pspec 中未提及网格轴名称 'j'
所表达的反平铺是安全的,这确保每个输出块沿相应的网格轴相等。以下是另外两个示例,我们改变了输出 pspec 中提及的网格轴
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
return jax.lax.psum(x_block, 'i')
x = jnp.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape) # (3,12)
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
return jax.lax.psum(x_block, ('i', 'j'))
y5 = f5(x)
print(y5.shape) # (3,6)
(3, 12)
(3, 6)
在物理方面,在输出 pspec 中未提及网格轴名称会从输出设备缓冲区组装一个 Array
,该缓冲区沿该网格轴具有复制的布局。
没有运行时检查来确保输出块实际上沿要反平铺的网格轴相等,或者等效地,相应的物理缓冲区具有相等的值,因此可以解释为单个逻辑数组的复制布局。但是我们可以提供一个静态检查机制,该机制会在所有可能不正确的程序上引发错误。
因为 out_specs
可以提及网格轴名称零次或一次,并且因为它们可以以任何顺序提及,我们可以说,除了内置于其输出的 jnp.concatenate
之外,shard_map
在其输出中还内置了*反平铺*和*块转置*。
无论输出 pspec 如何,都不可能在输出上进行物理数据移动。相反,out_specs
只是编码如何将块输出组装成 Array
,或者从物理上讲,如何将设备上的缓冲区解释为单个逻辑 Array
的物理布局。
API 规范#
from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]
def shard_map(
f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
auto: collections.abc.Set[AxisName] = frozenset([]),
check_rep: bool = True,
) -> Callable:
...
其中
f
主体中的通信集合(如psum
)可以提及mesh
的轴名称;mesh
对以数组形式排列并具有相关轴名称的设备进行编码,就像它对sharding.NamedSharding
所做的那样;in_specs
和out_specs
是PartitionSpec
,可以仿射地提及mesh
中的轴名称,分别表达输入和输出的切片/取消连接和连接,其中未提及的名称分别对应于复制和反平铺(断言-复制-因此-给我-一份副本);auto
是一个可选的轴名称集合,对应于mesh
的名称的子集,用于在主体中自动处理,就像在调用方中一样,而不是手动处理;check_rep
是一个可选的布尔值,指示是否静态检查out_specs
中是否存在任何复制错误,以及是否启用相关的自动微分优化(请参阅 JEP)。
传递给 f
的参数的形状与传递给 shard_map
-of-f
的参数的形状具有相同的秩,并且 f
的参数的形状是从 shard_map
-of-f
的相应参数的形状 shape
和相应的 PartitionSpec
spec
计算得出的,大致为 tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))
。
集合教程#
shard_map
不必是纯映射:函数应用程序可以使用在 mesh
参数中定义的轴名称,通过*集合*相互通信。
回想一下,shard_map
将一个函数映射到输入数据的分片或块上,因此
mesh = Mesh(jax.devices(), ('i',))
x = jnp.arange(16.)
f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))
y = f_shmapped(x)
计算相同的值,对相同的参数值评估 f
的应用程序,就像此参考函数一样
def f_shmapped_ref(x):
x_blocks = jnp.array_split(x, mesh.shape['i'])
y_blocks = [f(x_blk) for x_blk in x_blocks]
return jnp.concatenate(y_blocks)
我们将 f
对不同参数分片的这些应用程序称为*函数实例*。每个函数实例都在不同的设备(或设备子集)上执行。
当 f
中没有通信集合时,这些参考语义起作用。但是,如果我们希望函数实例进行通信,从而对应于跨设备通信,该怎么办?也就是说,当 f
包含集合时,参考语义是什么?假设 f
只有一个集合,并且形式为
def f(x_blk):
z_blk = f_part1(x_blk)
u_blk = collective(z_blk, axis_name)
v_blk = f_part2(x_blk, z_blk, u_blk)
return v_blk
其中,我们假设我们仅映射一个网格轴,并且 axis_name
是其对应的名称。那么参考语义将更像
def f_shmapped_ref(x):
x_blocks = jnp.array_split(x, mesh.shape[0])
z_blocks = [f_part1(x_blk) for x_blk in x_blocks]
u_blocks = [collective_ref(i, z_blocks) for i in range(len(z_blocks))]
v_blocks = [f_part2(x_blk, z_blk, u_blk) for x_blk, z_blk, u_blk
in zip(x_blocks, z_blocks, u_blocks)]
return jnp.concatenate(v_blocks)
请注意,collective_ref
可能取决于所有 z_blocks
。也就是说,虽然 f_part1
和 f_part2
是独立映射到块上的,但集合引入了某种程度的跨块依赖关系。在物理上,这意味着跨设备的通信。发生的确切通信和计算的值取决于集合。
psum
#
最简单的集合可能是 jax.lax.psum
,它沿设备网格轴(或多个轴)计算所有简化求和。这是一个玩具示例
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh1d = Mesh(jax.devices()[:4], ('i',))
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))
def f1(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum(x_block, 'i')
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f1(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22 20 12 17]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[22 20 12 17]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[22 20 12 17]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[22 20 12 17]
FINAL RESULT:
[22 20 12 17]
打印结果表明,每个函数应用程序都从其自己的参数值 x_block
块开始。在 psum
之后,每个函数应用程序都具有相同的 y_block
值,该值是通过将应用程序的 x_block
值相加而计算得出的。
在计算中只有一个轴名称的情况下,我们可以说 psum
的 collective_ref
参考实现为
def psum_ref(_, x_blocks):
tot = sum(x_blocks)
return [tot] * len(x_blocks)
还要注意,因为 f1
返回 y_block
,即在 'i'
上进行 psum
的结果,所以我们可以使用 out_specs=P()
,以便调用方获得结果值的单个逻辑副本,而不是平铺的结果。
当有多个网格轴时,我们可以分别在每个轴上执行 psum
,也可以一次在多个轴上执行 psum
mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))
@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f2(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum(x_block, 'i')
print('AFTER:\n', y_block)
return y_block
y = f2(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
[4 5]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
[6 7]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 9]
[12 13]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
[14 15]]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[ 8 10]
[16 18]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[12 14]
[20 22]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 10]
[16 18]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[12 14]
[20 22]]
FINAL RESULT:
[[ 8 10 12 14]
[16 18 20 22]]
通过对网格轴 'i'
应用 psum
,我们得到的 y_block
值在轴 'i'
上相等,但在轴 'j'
上不相等。(因此,我们可以使用 out_specs=P(None, 'j')
来获得沿该轴的单个逻辑结果。)
如果我们对两个轴都应用 psum
,则 y_block
值在两个轴上都相等。
@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))
def f3(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum(x_block, ('i', 'j'))
print('AFTER:\n', y_block)
return y_block
y = f3(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
[4 5]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
[6 7]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 9]
[12 13]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
[14 15]]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[20 24]
[36 40]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[20 24]
[36 40]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[20 24]
[36 40]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[20 24]
[36 40]]
FINAL RESULT:
[[20 24]
[36 40]]
在机器学习中,我们经常使用 psum
来计算总损失,或者当我们在 shard_map
映射的函数体内有 grad
时,计算总梯度。
接下来,我们将看到如何用其他原语来实现 psum
,这能提供一些关于其通信成本的直观理解。
all_gather
#
另一个基本操作是沿轴收集数组分片,以便每个函数应用程序都有沿该轴的完整数据副本。
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f4(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.all_gather(x_block, 'i', tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 9, 5, 2])
y = f4(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 9 5 2]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[3 9 5 2]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[3 9 5 2]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[3 9 5 2]
FINAL RESULT:
[3 9 5 2 3 9 5 2 3 9 5 2 3 9 5 2]
打印输出显示,每个函数应用程序再次从其自身的参数值 x_block
块开始。在 all_gather
之后,它们具有一个公共值,该值通过连接 x_block
的值计算得出。
(请注意,我们实际上不能在这里设置 out_specs=P()
。出于与自动微分相关的技术原因,我们认为 all_gather
的输出不能保证在设备之间是不变的。如果我们想要保证它不变,我们可以使用 jax.lax.all_gather_invariant
,或者在这种情况下,我们可以避免在函数体中执行 all_gather
,而只需使用 out_specs=P('i')
来执行连接。)
当 tiled=False
(默认值)时,结果是沿新轴堆叠而不是连接的。
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f5(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.all_gather(x_block, 'i', tiled=False)
print('AFTER:\n', y_block)
return y_block
y = f5(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[[3]
[9]
[5]
[2]]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[[3]
[9]
[5]
[2]]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[[3]
[9]
[5]
[2]]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[[3]
[9]
[5]
[2]]
FINAL RESULT:
[[3]
[9]
[5]
[2]
[3]
[9]
[5]
[2]
[3]
[9]
[5]
[2]
[3]
[9]
[5]
[2]]
我们可以将 all_gather
的 collective_ref
引用语义函数写成
def all_gather_ref(_, x_blocks, *, tiled=False):
combine = jnp.concatenate if tiled else jnp.stack
return [combine(x_blocks)] * len(x_blocks)
在深度学习中,我们可能会在完全分片数据并行 (FSDP) 中对参数使用 all_gather
。
psum_scatter
#
jax.lax.psum_scatter
集合操作有点不太直观。它类似于 psum
,只是每个函数实例只获得结果的一个分片。
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f6(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f6(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]
FINAL RESULT:
[22 20 12 17]
如打印输出所示,每个生成的 y_block
的大小都小于参数 x_block
,这与 psum
不同。此外,与 psum
相比,这里的每个 y_block
仅表示函数实例中 x_block
总和的切片。(即使每个函数实例只获得总和的一个分片,最终输出 y
与 psum
示例中的相同,因为这里我们使用 out_specs=P('i')
来连接每个函数实例的输出。)
就计算的值而言,collective_ref
引用实现可能如下所示:
def psum_scatter_ref(i, x_blocks, *, tiled=False):
axis_size = len(x_blocks)
tot = sum(x_blocks)
if tiled:
tot = tot.reshape(axis_size, -1, *tot.shape[1:]) # split leading axis
return [tot[i] for i in range(tot.shape[0])]
语义引用实现中没有捕获这一点,但 psum_scatter
是有用的,因为这些结果可以更高效地计算,通信量比完整的 psum
少。实际上,考虑 psum_scatter
的一种方式是将其视为“psum
的前半部分,在 all_gather
之前”。也就是说,实现 psum
的一种方法是
def psum(x, axis_name):
summed_chunk = jax.lax.psum_scatter(x, axis_name)
return jax.lax.all_gather(summed_chunk, axis_name)
事实上,这种实现经常在 TPU 和 GPU 上使用!
psum_scatter
需要大约一半的通信量才能完成完整的 psum
的原因是 ppermute
部分中说明的。
另一种直观的理解是,我们可以使用 psum_scatter
来实现分布式矩阵乘法,其中输入和输出都沿同一轴分片。在机器学习中,psum_scatter
可用于张量并行矩阵乘法或完全分片的数据并行梯度累积,如下面的示例所示。
ppermute
#
jax.lax.ppermute
集合操作为函数实例提供了一种最直接的方式来相互发送数据。给定一个网格轴和一个表示沿该网格轴的索引的 (source_index, destination_index)
对的列表,ppermute
将其参数值从每个源函数实例发送到每个目标。
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f7(x_block):
sz = jax.lax.psum(1, 'i')
print('BEFORE:\n', x_block)
y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)])
print('AFTER:\n', y_block)
return y_block
y = f7(jnp.arange(8))
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[0 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[2 3]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[6 7]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[6 7]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[0 1]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[2 3]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[4 5]
FINAL RESULT:
[6 7 0 1 2 3 4 5]
在这种情况下,只有两个函数实例,每个实例的 y_block
值是另一个实例的 x_block
值。
源索引和目标索引不能重复。如果某个索引没有作为目标出现,则对应函数实例的结果值是一个零数组。
collective_ref
引用实现可能如下所示:
def ppermute_ref(i, x_blocks, perm):
results = [jnp.zeros_like(x_blocks[0])] * len(x_blocks)
for src, dst in perm:
results[dst] = x_blocks[src]
return results
可以使用 ppermute
高效地实现其他集合操作,其中每个函数仅将数据传递给其相邻函数,从而实现总通信量。例如,我们可以使用一系列 ppermute
和局部加法来实现 psum_scatter
,如下所示:
或者,用一个数值示例:
直观地说,在每次迭代中,每个函数实例都会“向上”发送它在前一次迭代中收到的值,并减少(添加)它在这次迭代中收到的值。在代码中,它可能如下所示:
def psum_scatter(x, axis_name, *, tiled=False):
size = jax.lax.psum(1, axis_name)
idx = jax.lax.axis_index(axis_name) # function instance index along axis_name
if tiled:
x = x.reshape(size, -1, *x.shape[1:]) # split leading axis
shift = partial(jax.lax.ppermute, axis_name=axis_name,
perm=[(i, (i - 1) % size) for i in range(size)])
for i in range(1, size):
update = shift(x[(idx + i) % size])
x = x.at[(idx + i + 1) % size].add(update)
return x[idx]
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f8(x_block):
print('BEFORE:\n', x_block)
y_block = psum_scatter(x_block, 'i', tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f8(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]
FINAL RESULT:
[22 20 12 17]
在 TPU 上,此算法具有更高维度的变体,以利用多个双向物理网格轴。
请注意,psum_scatter
是 all_gather
的转置。实际上,使用 ppermute
实现 all_gather
的一种方法看起来像是上述过程的反向操作:
在深度学习中,我们可能会在实现 SPMD 流水线并行时使用 ppermute
,我们将网络沿其深度划分为多个阶段,并并行评估各阶段的应用。或者,我们可能会在并行化卷积层的评估时使用 ppermute
,我们沿空间轴进行分片,因此设备必须相互传递“光晕”。或者,它可以在幕后用于张量并行矩阵乘法。
all_to_all
#
最后一个集合操作是 all_to_all
,它本质上是沿一个位置轴和一个跨设备轴操作的块矩阵转置。
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f9(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0,
tiled=True)
print('AFTER:\n', y_block)
return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f9(x)
print('FINAL RESULT:\n', y)
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 5 5 9]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[1 9 3 7]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 2 5 1]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[1 6 8 2]
FINAL RESULT:
[3 5 5 9 1 9 3 7 4 2 5 1 1 6 8 2]
split_axis
参数指示哪个位置轴应分片并在网格轴上进行分区。concat_axis
参数指示应沿哪个轴连接或堆叠通信的结果。
当 tiled=False
(默认值)时,split_axis
轴的大小必须等于名为 axis_name
的网格轴的大小,并且在该大小的新轴在位置 concat_axis
上为堆叠结果创建。当 tiled=True
时,split_axis
轴的大小只需能被网格轴的大小均匀整除,并且结果将沿现有轴 concat_axis
连接。
当 split_axis=0
且 concat_axis=0
时的 collective_ref
引用语义可能如下所示:
def all_to_all_ref(_, x_blocks, *, tiled=False):
axis_size = len(x_blocks)
if tiled:
splits = [jnp.array_split(x, axis_size) for x in x_blocks]
return [jnp.concatenate(s) for s in zip(*splits)]
else:
splits = [list(x) for x in x_blocks]
return [jnp.stack(s) for s in zip(*splits)]
在深度学习中,我们可能会在混合专家路由中使用 all_to_all
,我们首先根据示例应转到的专家对本地批次示例进行排序,然后应用 all_to_all
将示例重新分发给专家。
示例#
我们如何在实践中使用 shard_map
和集合通信?以下示例虽然简单,但也提供了一些思路。
矩阵乘法#
并行化矩阵乘法是扩展深度学习模型的关键,无论是训练还是推理。当 jax.jit
自动并行化矩阵乘法时,它可以根据矩阵大小、硬件细节和其他因素使用几种不同的策略。我们如何使用 shard_map
更明确地编写其中的一些并行化例程?我们如何优化它们以获得更好的计算/通信重叠,从而提高 FLOP 利用率?
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh = Mesh(jax.devices()[:4], ('i',))
def device_put(x, pspec):
return jax.device_put(x, NamedSharding(mesh, pspec))
示例 1:单侧 all-gather
#
考虑执行矩阵乘法,其中我们在其前导(非收缩)维度上对左侧参数(可以认为是参数)进行分片
lhs_spec = P('i', None)
lhs = device_put(jax.random.normal(jax.random.key(0), (8, 8)), lhs_spec)
我们在其收缩维度上对右侧参数(可以认为是激活)进行分片,并对输出进行类似的分片
rhs_spec = P('i', None)
rhs = device_put(jax.random.normal(jax.random.key(1), (8, 4)), rhs_spec)
要执行此矩阵乘法,我们可以首先全收集右侧,然后针对分片的左侧执行局部矩阵乘法
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_allgather(lhs_block, rhs_block):
rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)
return lhs_block @ rhs
out = matmul_allgather(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
这很好,但我们在这里没有得到任何计算/通信重叠:在我们可以开始 matmul 之前,我们需要 all_gather
完成。这是使用相同代码但在更大的示例形状(lhs
的 (8192, 8192)
和 rhs
的 (8192, 1024)
)上的配置文件
如果不是调用 all_gather
,而是基本上以内联方式使用我们上面 all_gather
的 ppermute
实现,然后将收集置换的步骤与局部矩阵乘法交错,就可以获得计算/通信重叠
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_allgather_overlapped(lhs_block, rhs_block):
size = jax.lax.psum(1, 'i')
idx = jax.lax.axis_index('i')
shift = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i + 1) % size) for i in range(size)])
B = lhs_block.shape[1] // size
lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)
out_block = lhs_blocks(idx) @ rhs_block
for i in range(1, size):
rhs_block = shift(rhs_block)
out_block += lhs_blocks((idx - i) % size) @ rhs_block
return out_block
out = matmul_allgather_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
此实现允许通信和计算之间的重叠,并避免在每个设备上收集大型中间体。但在 TPU 上,它通过仅沿环的一个方向置换来使用一半的互连带宽。要双向置换,我们只需将块分成两半,并沿每个方向发送一半
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):
size = jax.lax.psum(1, 'i')
idx = jax.lax.axis_index('i')
shift_up = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i + 1) % size) for i in range(size)])
shift_dn = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i - 1) % size) for i in range(size)])
B = lhs_block.shape[1] // size // 2 # half-size blocks
lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 1)
rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0)
out_block = lhs_blocks(idx, 0) @ rhs_block_lo
out_block += lhs_blocks(idx, 1) @ rhs_block_hi
for i in range(1, size):
rhs_block_lo = shift_up(rhs_block_lo)
rhs_block_hi = shift_dn(rhs_block_hi)
out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo
out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi
return out_block
out = matmul_allgather_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
实际上,为了减少编译时间,我们可能会将其滚入 jax.lax.fori_loop
。我们可能还会涉及其他并行轴。
示例 2:psum_scatter
结果#
我们可能开始的另一种分片方式是沿着它们的收缩维度对 lhs
和 rhs
进行分片,输出再次像 rhs
一样分片
lhs_spec = P(None, 'i')
lhs = device_put(lhs, lhs_spec)
rhs_spec = P('i', None)
rhs = device_put(rhs, rhs_spec)
在这里,我们可以使用 reduce_scatter
来执行分片上的收缩求和
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_psumscatter(lhs_block, rhs_block):
out_summand = lhs_block @ rhs_block
return jax.lax.psum_scatter(out_summand, 'i', tiled=True)
out = matmul_psumscatter(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
但是,散射通信必须等待整个局部矩阵乘法完成后才能开始。为了获得通信/计算重叠,我们可以以内联方式使用 ppermute
的 psum_scatter
实现,然后将通信步骤与局部矩阵乘法交错
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_psumscatter_overlapped(lhs_block, rhs_block):
size = jax.lax.psum(1, 'i')
idx = jax.lax.axis_index('i')
shift = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i - 1) % size) for i in range(size)])
lhs_block = lhs_block.reshape(size, -1, lhs_block.shape[1]) # split 1st axis
out_summand = lhs_block[(idx + 1) % size] @ rhs_block
for i in range(1, size):
out_summand = shift(out_summand)
out_summand += lhs_block[(idx + i + 1) % size] @ rhs_block
return out_summand
out = matmul_psumscatter_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
与前面的示例一样,为了充分利用 TPU 上的互连,我们将运行双向版本
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
out_specs=rhs_spec)
def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):
size = jax.lax.psum(1, 'i')
idx = jax.lax.axis_index('i')
shift_up = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i + 1) % size) for i in range(size)])
shift_dn = partial(jax.lax.ppermute, axis_name='i',
perm=[(i, (i - 1) % size) for i in range(size)])
B = lhs_block.shape[0] // size // 2 # half-size blocks
lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 0)
out_summand_lo = lhs_blocks((idx - 1) % size, 0) @ rhs_block
out_summand_hi = lhs_blocks((idx + 1) % size, 1) @ rhs_block
for i in range(1, size):
out_summand_lo = shift_up(out_summand_lo)
out_summand_hi = shift_dn(out_summand_hi)
out_summand_lo += lhs_blocks((idx - i - 1) % size, 0) @ rhs_block
out_summand_hi += lhs_blocks((idx + i + 1) % size, 1) @ rhs_block
return jnp.concatenate([out_summand_lo, out_summand_hi])
out = matmul_psumscatter_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
神经网络#
我们可以使用 shard_map
来并行化神经网络中的计算,既可以单独使用,也可以与 jax.jit
中的自动分区结合使用。本节包含一些基于此玩具神经网络和随机数据的示例
import jax
import jax.numpy as jnp
def predict(params, inputs):
for W, b in params:
outputs = jnp.dot(inputs, W) + b
inputs = jax.nn.relu(outputs)
return outputs
def loss(params, batch):
inputs, targets = batch
predictions = predict(params, inputs)
return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
def init_layer(key, n_in, n_out):
k1, k2 = jax.random.split(key)
W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
b = jax.random.normal(k2, (n_out,))
return W, b
def init(key, layer_sizes, batch_size):
key, *keys = jax.random.split(key, len(layer_sizes))
params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
key, *keys = jax.random.split(key, 3)
inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))
return params, (inputs, targets)
layer_sizes = [784, 128, 128, 128, 128, 128, 8]
batch_size = 32
params, batch = init(jax.random.key(0), layer_sizes, batch_size)
将这些示例与“分布式数组和自动分区”文档中的纯自动分区示例进行比较。虽然在那些自动分区示例中,我们不需要编辑模型函数来使用不同的并行化策略,但使用 shard_map
,我们通常需要这样做。
8 路批数据并行#
最简单的多设备并行策略是将输入和目标批次分片到多个设备上,在这些设备上复制参数,并并行地将模型应用于这些数据分片。为了评估总损失,设备只需要在最后使用标量大小的全归约求和进行通信。(为了评估损失的梯度,设备必须在反向传递中执行参数梯度的全归约求和。)
from functools import partial
from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh = jax.make_mesh((8,), ('batch',))
# replicate initial params on all devices, shard data batch over devices
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P()))
# adapt the loss function to sum the losses across devices
def loss_dp(params, batch):
@partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())
def loss_spmd(local_batch):
inputs, targets = local_batch
predictions = predict(params, inputs) # use reference 'predict`
local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
return jax.lax.pmean(local_loss, 'batch')
return loss_spmd(batch)
我们可以检查损失及其梯度是否与参考(基本)模型匹配
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_dp)(params, batch))
22.779888
22.779888
def allclose(a, b):
return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
jax.jit(jax.grad(loss_dp))(params, batch)))
True
我们可以打印编译器 IR 以检查梯度计算,并验证集体全归约求和操作发生在我们期望的位置:在正向传递的末尾计算损失值,以及在反向传递中计算总参数梯度。
8 路完全分片数据并行 (FSDP)#
另一种策略是在设备上额外分片参数,当 jnp.dot
或偏差相加需要完整值时,全收集每个参数。由于我们一次只有一个完整参数在本地设备内存中,而不是像前面的 DP 示例中那样在所有设备内存中保留所有参数,因此我们释放了大量的内存,可以将其用于更大的模型或更大的批次大小。而且,由于 XLA 会重叠计算和设备间通信,因此挂钟时间不会受到影响。
因此,现在我们需要在两个地方进行集合:模型预测函数 predict
需要在参数使用之前全收集它们,并且与 DP 情况一样,损失函数需要对局部损失求和以计算总损失。
我们还需要另一个要素:我们不希望存储正向传递中完全收集的参数以用于反向传递。相反,我们希望在反向传递中再次收集它们。我们可以使用 jax.remat
和自定义策略(或 custom_vjp
)来表达这一点,尽管 XLA 通常会自动执行该重新实现。
这种通用的FSDP 方法类似于 权重更新分片 (WUS) 和 ZeRO-3。
# shard data batch *and params* over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P('batch')))
# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them)
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp(params_frag, inputs):
for W_frag, b_frag in params_frag:
W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
outputs = jnp.dot(inputs, W) + b
inputs = jax.nn.relu(outputs)
return outputs
def loss_fsdp(params, batch):
@partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())
def loss_spmd(local_params, local_batch):
inputs, targets = local_batch
predictions = predict_fsdp(local_params, inputs)
local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
return jax.lax.pmean(local_loss, 'batch')
return loss_spmd(params, batch)
同样,我们可以检查损失及其梯度是否与参考模型匹配
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp)(params, batch))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
jax.jit(jax.grad(loss_fsdp))(params, batch)))
22.779888
22.779888
True
8 路张量并行 (TP)#
通常我们不单独使用张量模型并行,但单独查看它对于并行矩阵乘法来说是一个很好的热身。这也是在更大的基于 jit
的计算中调用 shard_map
在库函数中使用的很好的例子。
并行化思想是,我们将数据/激活保持在其特征轴(而不是其批次轴)上的分片,并且我们类似地在其输入特征轴上分片权重矩阵(并在其特征轴上分片偏差)。然后,为了执行并行矩阵乘法,我们将执行局部矩阵乘法,然后执行 psum_scatter
以对局部结果求和并有效地散布结果的分片。
mesh = jax.make_mesh((8,), ('feats',))
batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))
params = jax.device_put(params, NamedSharding(mesh, P('feats')))
def predict_tp(params, inputs):
for W, b in params:
outputs = gemm_tp(inputs, W, b)
inputs = jax.nn.relu(outputs)
return outputs
@partial(shard_map, mesh=mesh,
in_specs=(P(None, 'feats'), P('feats', None), P('feats')),
out_specs=P(None, 'feats'))
def gemm_tp(inputs, W, b):
block_result = jnp.dot(inputs, W)
return jax.lax.psum_scatter(block_result, 'feats',
scatter_dimension=1, tiled=True) + b
def loss_tp(params, batch):
inputs, targets = batch
predictions = predict_tp(params, inputs)
return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1)) # NOTE psum!
FSDP + TP,顶层使用 shard_map
#
我们可以将这些策略组合在一起,使用多个并行轴。
mesh = jax.make_mesh((4, 2), ('batch', 'feats'))
batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))
params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))
# mostly same as previous predict_fsdp definition, except we call gemm_tp
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp_tp(params_frag, inputs):
for W_frag, b_frag in params_frag:
W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
block_result = jnp.dot(inputs, W)
outputs = jax.lax.psum_scatter(block_result, 'feats',
scatter_dimension=1, tiled=True) + b
inputs = jax.nn.relu(outputs)
return outputs
@partial(shard_map, mesh=mesh,
in_specs=(P(('feats', 'batch')), P('batch', 'feats')),
out_specs=P())
def loss_fsdp_tp(local_params, local_batch):
inputs, targets = local_batch
predictions = predict_fsdp_tp(local_params, inputs)
sq_err = jax.lax.psum(jnp.sum((predictions - targets)**2, axis=-1), 'feats')
return jax.lax.pmean(jnp.mean(sq_err), 'batch')
请注意,我们必须进行两次集体归约:一次针对 'feats'
,一次针对 'batch'
。在纯 TP 示例中,我们没有显式地编写 'feats'
归约,因为我们只在 gemm_tp
中使用了 shard_map
;在调用方 loss_tp
中,编译器根据 predict_tp
返回的分片结果,自动将我们对 jnp.sum
的使用转换为根据需要执行 psum
。
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp_tp)(params_, batch_))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
jax.jit(jax.grad(loss_fsdp_tp))(params, batch)))
22.779886
22.779886
True
SPMD 流水线并行 (PP)#
通过流水线并行,我们的目标是并行化网络中不同深度的层的评估。例如,一个设备可能会计算第一层的应用,而另一个设备可能会计算第二层的应用;当它们完成时,第一个设备将其结果传递给第二个设备,而第二个设备将其结果传递给负责第三层的设备,并且该过程重复进行。通常,流水线阶段的数量可能与层的数量不同,因为每个阶段可能负责多层。
借助 SPMD 流水线,我们可以利用以下事实:网络中的大多数层都应用计算,只是参数值不同。特别是,我们可以将除第一层和最后一层的参数之外的所有参数堆叠在一起,然后使用 shard_map
来映射这些层参数的块,其中每个参数块都对应一个流水线阶段。然后,我们使用 jax.lax.ppermute
集合体来向下移动并行流水线中的数据。
这种特定的流水线策略本质上是GPipe 策略。有几种变体,以及完全不同的策略,哪种策略合适取决于阶段之间的网络速度和批次大小。但对于本教程,我们将只关注一种策略。
首先,我们选择一些流水线参数
L = len(params) - 2 # num layers, excluding first and last
N = batch_size # batch size
F = params[0][0].shape[1] # num features
# choose some pipeline parameters
S = 2 # number of stages
B = 8 # size of each microbatch
assert L % S == 0, "S (number of stages) must divide L (number of inner layers)"
# compute some useful quantities
M, ragged = divmod(N, B) # M is number of microbatches
assert not ragged, "B (size of each microbatch) must divide total batch size"
K, ragged = divmod(M, S) # K is microbatches per stage
assert not ragged, "S (number of stages) must divide number of microbatches"
print(f'{S} stages, {L // S} layer(s) per stage, {L} pipelined layers total')
print(f'{B} examples per microbatch, {M} microbatches total')
2 stages, 2 layer(s) per stage, 4 pipelined layers total
8 examples per microbatch, 4 microbatches total
mesh = Mesh(jax.devices()[:S], ('stages',))
def predict_pp(params, inputs):
(W_first, b_first), inner_params, (W_last, b_last) = params
inputs = jax.nn.relu(jnp.dot(inputs, W_first) + b_first)
inputs = spmd_pipeline(lambda Wb, x: jax.nn.relu(x @ Wb[0] + Wb[1]),
inner_params, inputs)
outputs = jnp.dot(inputs, W_last) + b_last
return outputs
@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),
out_specs=P())
def loss_pp(params, batch):
inputs, targets = batch
predictions = predict_pp(params, inputs.reshape(K, B, -1)).reshape(K * B, -1)
local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
return jax.lax.pmean(local_loss, 'stages')
def spmd_pipeline(fn, stage_params, inputs):
stage = jax.lax.axis_index('stages')
outputs = jnp.zeros_like(inputs) * jnp.nan
state = jnp.zeros((L // S, B, F)) * jnp.nan
for i in range(M+L-1):
state = state.at[0].set(jnp.where(stage == 0, inputs[i % K], state[0]))
state = jax.vmap(fn)(stage_params, state)
outputs = outputs.at[(i-L+1) % K].set(jnp.where(stage == S-1, state[-1], outputs[(i-L+1) % K]))
state, inputs, outputs = shift(i, state, inputs, outputs)
outputs = jax.lax.ppermute(outputs, 'stages', [(i, (i+1) % S) for i in range(S)])
return outputs
def shift(i, state, inputs, outputs):
sh = lambda x, d: jax.lax.ppermute(x, 'stages', [(i, (i+d) % S) for i in range(S)])
state = jnp.roll(state, +1, axis=0).at[0].set(sh(state[-1], +1))
if (i % K) == (-1 % K):
inputs = sh(inputs, +1)
if ((i-L+1) % K) == (-1 % K):
outputs = sh(outputs, +1)
return state, inputs, outputs
first_params, *inner_params, last_params = params
Ws, bs = zip(*inner_params)
params_stacked = jnp.stack(Ws), jnp.stack(bs)
first_params = jax.device_put(first_params, NamedSharding(mesh, P()))
params_stacked = jax.device_put(params_stacked, NamedSharding(mesh, P('stages')))
last_params = jax.device_put(last_params, NamedSharding(mesh, P()))
params_ = first_params, params_stacked, last_params
batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages')))
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_pp)(params_, batch_))
22.779886
22.779884
_ = jax.jit(jax.grad(loss_pp))(params_, batch_) # don't crash