使用 shard_map 进行手动并行#

概述#

shard_map 是一个单程序多数据 (SPMD) 多设备并行 API,用于将函数映射到数据分片。映射的函数应用程序或实例通过显式集体通信操作相互通信。

shard_map 与内置于 jit 中的自动编译器并行化是互补的,并且可以组合使用。使用 jit,您可以像为单个设备编写代码一样,并且编译器可以自动将计算分配到多个设备上,在后台生成每个设备的代码和通信集合。使用 shard_map,您可以控制编写自己的分区代码和显式集合。或者,您可以两者都做:在设备组之间进行手动控制,同时将组内设备分区留给编译器处理。这两种方法可以根据需要混合、匹配和组合。

如果您熟悉 pmap,则可以将 shard_map 视为一种演变。它更具表现力、性能更高,并且可以与其他 JAX API 组合使用。它甚至可以急切地工作,以便于调试!(有关更多信息,请参阅pmap 的详细比较。

通过阅读本教程,您将学习如何使用 shard_map 来完全控制您的多设备代码。您将详细了解它如何与 jax.jit 的自动并行化和 jax.grad 的自动微分组合使用。我们还将提供一些神经网络并行化策略的基本示例。

我们将假设本教程在具有八个设备的环境中运行

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices

因此,让我们看看 shard_map#

事不宜迟,这是一个玩具示例

from functools import partial

import jax
import jax.numpy as jnp

from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh = jax.make_mesh((4, 2), ('x', 'y'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 *  4.).reshape(16, 4)

@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
         out_specs=P('x', None))
def matmul_basic(a_block, b_block):
  # a_block: f32[2, 8]
  # b_block: f32[8, 4]
  c_partialsum = jnp.dot(a_block, b_block)
  c_block = jax.lax.psum(c_partialsum, 'y')
  # c_block: f32[2, 4]
  return c_block

c = matmul_basic(a, b)   # c: f32[8, 4]

此函数通过执行局部块矩阵乘法,然后执行集体求和操作来并行计算矩阵乘法。我们可以检查结果是否正确

from jax.tree_util import tree_map, tree_all

def allclose(a, b):
  return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))

allclose(c, jnp.dot(a, b))
True

结果沿其行进行分片

jax.debug.visualize_array_sharding(c)
            
  CPU 0,1   
            
            
  CPU 2,3   
            
            
  CPU 4,5   
            
            
  CPU 6,7   
            

从高层次上讲,shard_map 有点像 vmappmap,因为我们正在将一个函数映射到数组数据的片段上,但请注意

  • shard_map 将输入切成块(并且输出是通过连接结果块形成的),保持秩不变,而 vmap 会通过映射掉一个轴来降低秩;

  • mesh 参数让我们控制计算和结果的精确设备放置;

  • 我们正在一次映射多个数据轴,并为集合设置多个轴名称(此处为 'x''y');

  • 由于我们尚未使用 jax.jit,因此所有内容都会被急切地评估,我们甚至可以 print 中间值以进行调试。

上面的代码执行与此 jax.jit 自动并行化代码相同的计算

from jax.sharding import NamedSharding

a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))
b = jax.device_put(b, NamedSharding(mesh, P('y', None)))

@jax.jit
def matmul_reference(a, b):
  c = jnp.dot(a, b)
  return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))

c_ref = matmul_reference(a, b)
allclose(c_ref, jnp.dot(a, b))
True

我们可以将 shard_map 视为根据其 meshin_specs 参数对其输入执行 device_putwith_sharding_constraint,因此 matmul_basic 操作的块与 matmul_reference 中的块相同

print('a blocks:'); jax.debug.visualize_array_sharding(a)
print('b blocks:'); jax.debug.visualize_array_sharding(b)
print('c blocks:'); jax.debug.visualize_array_sharding(c)
a blocks:
b blocks:
c blocks:
                                                  
          CPU 0                    CPU 1          
                                                  
                                                  
          CPU 2                    CPU 3          
                                                  
                                                  
          CPU 4                    CPU 5          
                                                  
                                                  
          CPU 6                    CPU 7          
                                                  
           
           
CPU 0,2,4,6
           
           
           
           
           
CPU 1,3,5,7
           
           
           
            
  CPU 0,1   
            
            
  CPU 2,3   
            
            
  CPU 4,5   
            
            
  CPU 6,7   
            

放慢速度,从基础知识开始!#

降秩映射与保秩映射#

我们可以将 vmappmap 视为沿轴拆堆每个数组输入(例如,将 2D 矩阵解压为其 1D 行),将其主体函数应用于每个片段,然后将结果堆叠在一起,至少在不涉及集合时是这样

def check_vmap(f, xs):
  ans = jax.vmap(f, in_axes=(0,), out_axes=0)(xs)
  expected = jnp.stack([f(x) for x in xs])  # vmap reference semantics
  print(allclose(ans, expected))

check_vmap(lambda x: x @ x, jnp.arange(12).reshape(4, 3))
True

例如,如果 xs 的形状为 f32[8,5],则每个 x 的形状将为 f32[5],并且如果每个 f(x) 的形状为 f32[3,7],则最终堆叠结果 vmap(f)(xs) 的形状将为 f32[8,3,7]。也就是说,主体函数 f 的每次应用都以比 vmap(f) 的相应参数少一个轴的输入作为参数。我们可以说这些是具有输入/输出拆堆/堆叠的降秩映射

f 的逻辑应用次数或 f实例由正在映射的输入轴的大小决定:例如,如果我们映射大小为 8 的输入轴,则在语义上我们会得到 8 个函数的逻辑应用。

相比之下,shard_map 没有这种降秩行为。相反,我们可以将其视为沿输入轴切片(或“取消连接”)成块,应用主体函数,然后将结果连接在一起(同样,在不涉及集合时)

import numpy as np
devices = np.array(jax.devices()[:4])
mesh = Mesh(devices, ('i',))  # mesh.shape['i'] = 4

def check_shmap(f, y):
  ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y)
  expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])])
  print(allclose(ans, expected))

check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4))
True

回想一下,jnp.split 将其输入切片成大小相等的同秩块,因此如果在上面的示例中 y 的形状为 f32[8,5],则每个 y_blk 的形状将为 f32[2,5],并且如果每个 f(y_blk) 的形状为 f32[3,7],则最终连接结果 shard_map(f, ...)(y) 的形状将为 f32[12,7]。因此,shard_map 映射其输入的分片或块。我们可以说它是一个具有输入/输出取消连接/连接的保秩映射

f 的逻辑应用次数由网格大小决定,而不是由任何输入轴大小决定:例如,如果我们有一个总大小为 4 的网格(即在 4 个设备上),那么在语义上我们会得到 4 个函数的逻辑应用,对应于物理计算它们的 4 个设备。

使用 in_specs 控制如何拆分(取消连接)和平铺每个输入#

每个 in_specs 都使用 PartitionSpec 按名称标识相应的输入数组的某些轴与网格轴,表示如何将该输入拆分(或取消连接)成应用主体函数的块。该标识确定分片大小;当输入轴被标识为网格轴时,输入将沿该逻辑轴拆分(取消连接)成等于相应网格轴大小的多个片段。(如果相应的网格轴大小不能均匀地分割输入数组轴大小,则会出错。)如果输入的 pspec 没有提到网格轴名称,则不会在该网格轴上进行拆分。例如

mesh = jax.make_mesh((4, 2), ('i', 'j'))

@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
  print(x_block.shape)  # prints (3, 12)
  return x_block

x1 = jnp.arange(12 * 12).reshape(12, 12)
y = f1(x1)
(3, 12)

在这里,由于输入 pspec 没有提到网格轴名称 'j',因此没有在网格轴上拆分任何输入数组轴;类似地,由于输入数组的第二个轴未被标识为(因此未在)任何网格轴上拆分,因此应用 f1 会获得沿该轴输入的完整视图。

当在输入 pspec 中未提及网格轴时,我们可以始终重写为效率较低的程序,其中提及所有网格轴,但调用者执行 jnp.tile,例如

@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
  print(x_block.shape)
  return x_block

x = jnp.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, mesh.shape['j']))  # x_ has shape (12, 24)
y = f2(x_)  # prints (3,12), and f1(x) == f2(x_)
(3, 12)

换句话说,因为每个输入 pspec 可以提及每个网格轴名称零次或一次,而不是必须提及每个名称恰好一次,我们可以说,除了内置于其输入的 jnp.split 之外,shard_map 在其输入中还内置了 jnp.tile,至少在逻辑上是这样(尽管根据参数的物理分片布局,可能不需要实际执行平铺)。要使用的平铺不是唯一的;我们也可以沿着第一个轴平铺,并使用 pspec P(('j', 'i'), None)

输入端可能会发生物理数据移动,因为每个设备都需要拥有适当数据的副本。

使用 out_specs 控制如何通过串联、块转置和反平铺来组装每个输出#

与输入端类似,每个 out_specs 通过名称将相应的输出数组的一些轴与网格轴关联起来,表示应该如何将输出块(主体函数的每次应用一个,或者等效地每个物理设备一个)重新组装在一起以形成最终输出值。例如,在上面的 f1f2 示例中,out_specs 指示我们应该沿着两个轴将块结果连接在一起,从而形成最终输出,在两种情况下都得到一个形状为 (12, 24) 的数组 y。(如果主体函数的输出形状(即输出块形状)的秩太小,无法满足相应输出 pspec 描述的连接,则会出错。)

当输出 pspec 中未提及网格轴名称时,它表示反平铺:当用户编写未提及网格轴名称之一的输出 pspec 时,他们承诺输出块沿该网格轴相等,因此在输出中仅使用沿该轴的一个块(而不是沿该网格轴将所有块连接在一起)。例如,使用与上面相同的网格

x = jnp.array([[3.]])

z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()
print(z)  # prints the same as jnp.tile(x, (4, 2))

z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()
print(z)  # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))

z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()
print(z)  # prints the same as jnp.tile(x, (1, 1)), or just x
[[3. 3.]
 [3. 3.]
 [3. 3.]
 [3. 3.]]
[[3.]
 [3.]
 [3.]
 [3.]]
[[3.]]

封闭数组值的主体函数等效于将其作为参数传递,其对应的输入 pspec 为 P(None, None)。作为另一个示例,更接近上面的其他示例

@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
  return jax.lax.psum(x_block, 'j')

x = jnp.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape)
(12, 6)

结果的第二个轴大小为 6,是输入第二个轴大小的一半。在这种情况下,由于集体 psum,在输出 pspec 中未提及网格轴名称 'j' 所表达的反平铺是安全的,这确保每个输出块沿相应的网格轴相等。以下是另外两个示例,我们改变了输出 pspec 中提及的网格轴

@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
  return jax.lax.psum(x_block, 'i')

x = jnp.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape)  # (3,12)


@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
  return jax.lax.psum(x_block, ('i', 'j'))

y5 = f5(x)
print(y5.shape)  # (3,6)
(3, 12)
(3, 6)

在物理方面,在输出 pspec 中未提及网格轴名称会从输出设备缓冲区组装一个 Array,该缓冲区沿该网格轴具有复制的布局。

没有运行时检查来确保输出块实际上沿要反平铺的网格轴相等,或者等效地,相应的物理缓冲区具有相等的值,因此可以解释为单个逻辑数组的复制布局。但是我们可以提供一个静态检查机制,该机制会在所有可能不正确的程序上引发错误。

因为 out_specs 可以提及网格轴名称零次或一次,并且因为它们可以以任何顺序提及,我们可以说,除了内置于其输出的 jnp.concatenate 之外,shard_map 在其输出中还内置了*反平铺*和*块转置*。

无论输出 pspec 如何,都不可能在输出上进行物理数据移动。相反,out_specs 只是编码如何将块输出组装成 Array,或者从物理上讲,如何将设备上的缓冲区解释为单个逻辑 Array 的物理布局。

API 规范#

from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]

def shard_map(
    f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
    auto: collections.abc.Set[AxisName] = frozenset([]),
    check_rep: bool = True,
) -> Callable:
  ...

其中

  • f 主体中的通信集合(如 psum)可以提及 mesh 的轴名称;

  • mesh 对以数组形式排列并具有相关轴名称的设备进行编码,就像它对 sharding.NamedSharding 所做的那样;

  • in_specsout_specsPartitionSpec,可以仿射地提及 mesh 中的轴名称,分别表达输入和输出的切片/取消连接和连接,其中未提及的名称分别对应于复制和反平铺(断言-复制-因此-给我-一份副本);

  • auto 是一个可选的轴名称集合,对应于 mesh 的名称的子集,用于在主体中自动处理,就像在调用方中一样,而不是手动处理;

  • check_rep 是一个可选的布尔值,指示是否静态检查 out_specs 中是否存在任何复制错误,以及是否启用相关的自动微分优化(请参阅 JEP)。

传递给 f 的参数的形状与传递给 shard_map-of-f 的参数的形状具有相同的秩,并且 f 的参数的形状是从 shard_map-of-f 的相应参数的形状 shape 和相应的 PartitionSpec spec 计算得出的,大致为 tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))

集合教程#

shard_map 不必是纯映射:函数应用程序可以使用在 mesh 参数中定义的轴名称,通过*集合*相互通信。

回想一下,shard_map 将一个函数映射到输入数据的分片或块上,因此

mesh = Mesh(jax.devices(), ('i',))
x = jnp.arange(16.)
f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))
y = f_shmapped(x)

计算相同的值,对相同的参数值评估 f 的应用程序,就像此参考函数一样

def f_shmapped_ref(x):
  x_blocks = jnp.array_split(x, mesh.shape['i'])
  y_blocks = [f(x_blk) for x_blk in x_blocks]
  return jnp.concatenate(y_blocks)

我们将 f 对不同参数分片的这些应用程序称为*函数实例*。每个函数实例都在不同的设备(或设备子集)上执行。

f 中没有通信集合时,这些参考语义起作用。但是,如果我们希望函数实例进行通信,从而对应于跨设备通信,该怎么办?也就是说,当 f 包含集合时,参考语义是什么?假设 f 只有一个集合,并且形式为

def f(x_blk):
  z_blk = f_part1(x_blk)
  u_blk = collective(z_blk, axis_name)
  v_blk = f_part2(x_blk, z_blk, u_blk)
  return v_blk

其中,我们假设我们仅映射一个网格轴,并且 axis_name 是其对应的名称。那么参考语义将更像

def f_shmapped_ref(x):
  x_blocks = jnp.array_split(x, mesh.shape[0])
  z_blocks = [f_part1(x_blk) for x_blk in x_blocks]
  u_blocks = [collective_ref(i, z_blocks) for i in range(len(z_blocks))]
  v_blocks = [f_part2(x_blk, z_blk, u_blk) for x_blk, z_blk, u_blk
              in zip(x_blocks, z_blocks, u_blocks)]
  return jnp.concatenate(v_blocks)

请注意,collective_ref 可能取决于所有 z_blocks。也就是说,虽然 f_part1f_part2 是独立映射到块上的,但集合引入了某种程度的跨块依赖关系。在物理上,这意味着跨设备的通信。发生的确切通信和计算的值取决于集合。

psum#

最简单的集合可能是 jax.lax.psum,它沿设备网格轴(或多个轴)计算所有简化求和。这是一个玩具示例

Illustration of a psum computation.
import jax
import jax.numpy as jnp
from jax import lax

from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh1d = Mesh(jax.devices()[:4], ('i',))

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))
def f1(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, 'i')
  print('AFTER:\n', y_block)
  return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f1(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22 20 12 17]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[22 20 12 17]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[22 20 12 17]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[22 20 12 17]

FINAL RESULT:
 [22 20 12 17]

打印结果表明,每个函数应用程序都从其自己的参数值 x_block 块开始。在 psum 之后,每个函数应用程序都具有相同的 y_block 值,该值是通过将应用程序的 x_block 值相加而计算得出的。

在计算中只有一个轴名称的情况下,我们可以说 psumcollective_ref 参考实现为

def psum_ref(_, x_blocks):
  tot = sum(x_blocks)
  return [tot] * len(x_blocks)

还要注意,因为 f1 返回 y_block,即在 'i' 上进行 psum 的结果,所以我们可以使用 out_specs=P(),以便调用方获得结果值的单个逻辑副本,而不是平铺的结果。

当有多个网格轴时,我们可以分别在每个轴上执行 psum,也可以一次在多个轴上执行 psum

mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))

@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f2(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, 'i')
  print('AFTER:\n', y_block)
  return y_block

y = f2(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
 [4 5]]

On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
 [6 7]]

On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8  9]
 [12 13]]

On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
 [14 15]]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[ 8 10]
 [16 18]]

On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[12 14]
 [20 22]]

On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 10]
 [16 18]]

On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[12 14]
 [20 22]]

FINAL RESULT:
 [[ 8 10 12 14]
 [16 18 20 22]]

通过对网格轴 'i' 应用 psum,我们得到的 y_block 值在轴 'i' 上相等,但在轴 'j' 上不相等。(因此,我们可以使用 out_specs=P(None, 'j') 来获得沿该轴的单个逻辑结果。)

如果我们对两个轴都应用 psum,则 y_block 值在两个轴上都相等。

@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))
def f3(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, ('i', 'j'))
  print('AFTER:\n', y_block)
  return y_block

y = f3(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
 [4 5]]

On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
 [6 7]]

On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8  9]
 [12 13]]

On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
 [14 15]]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[20 24]
 [36 40]]

On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[20 24]
 [36 40]]

On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[20 24]
 [36 40]]

On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[20 24]
 [36 40]]

FINAL RESULT:
 [[20 24]
 [36 40]]

在机器学习中,我们经常使用 psum 来计算总损失,或者当我们在 shard_map 映射的函数体内有 grad 时,计算总梯度。

接下来,我们将看到如何用其他原语来实现 psum,这能提供一些关于其通信成本的直观理解。

all_gather#

另一个基本操作是沿轴收集数组分片,以便每个函数应用程序都有沿该轴的完整数据副本。

Illustration of an all_gather computation.
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f4(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_gather(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 9, 5, 2])
y = f4(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 9 5 2]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[3 9 5 2]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[3 9 5 2]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[3 9 5 2]

FINAL RESULT:
 [3 9 5 2 3 9 5 2 3 9 5 2 3 9 5 2]

打印输出显示,每个函数应用程序再次从其自身的参数值 x_block 块开始。在 all_gather 之后,它们具有一个公共值,该值通过连接 x_block 的值计算得出。

(请注意,我们实际上不能在这里设置 out_specs=P()。出于与自动微分相关的技术原因,我们认为 all_gather 的输出不能保证在设备之间是不变的。如果我们想要保证它不变,我们可以使用 jax.lax.all_gather_invariant,或者在这种情况下,我们可以避免在函数体中执行 all_gather,而只需使用 out_specs=P('i') 来执行连接。)

tiled=False (默认值)时,结果是沿新轴堆叠而不是连接的。

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f5(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_gather(x_block, 'i', tiled=False)
  print('AFTER:\n', y_block)
  return y_block

y = f5(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[[3]
 [9]
 [5]
 [2]]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[[3]
 [9]
 [5]
 [2]]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[[3]
 [9]
 [5]
 [2]]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[[3]
 [9]
 [5]
 [2]]
FINAL RESULT:
 [[3]
 [9]
 [5]
 [2]
 [3]
 [9]
 [5]
 [2]
 [3]
 [9]
 [5]
 [2]
 [3]
 [9]
 [5]
 [2]]

我们可以将 all_gathercollective_ref 引用语义函数写成

def all_gather_ref(_, x_blocks, *, tiled=False):
  combine = jnp.concatenate if tiled else jnp.stack
  return [combine(x_blocks)] * len(x_blocks)

在深度学习中,我们可能会在完全分片数据并行 (FSDP) 中对参数使用 all_gather

psum_scatter#

jax.lax.psum_scatter 集合操作有点不太直观。它类似于 psum,只是每个函数实例只获得结果的一个分片。

Illustration of a psum_scatter computation.
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f6(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f6(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]
FINAL RESULT:
 [22 20 12 17]

如打印输出所示,每个生成的 y_block 的大小都小于参数 x_block,这与 psum 不同。此外,与 psum 相比,这里的每个 y_block 仅表示函数实例中 x_block 总和的切片。(即使每个函数实例只获得总和的一个分片,最终输出 ypsum 示例中的相同,因为这里我们使用 out_specs=P('i') 来连接每个函数实例的输出。)

就计算的值而言,collective_ref 引用实现可能如下所示:

def psum_scatter_ref(i, x_blocks, *, tiled=False):
  axis_size = len(x_blocks)
  tot = sum(x_blocks)
  if tiled:
    tot = tot.reshape(axis_size, -1, *tot.shape[1:])  # split leading axis
  return [tot[i] for i in range(tot.shape[0])]

语义引用实现中没有捕获这一点,但 psum_scatter 是有用的,因为这些结果可以更高效地计算,通信量比完整的 psum 少。实际上,考虑 psum_scatter 的一种方式是将其视为“psum 的前半部分,在 all_gather 之前”。也就是说,实现 psum 的一种方法是

def psum(x, axis_name):
  summed_chunk = jax.lax.psum_scatter(x, axis_name)
  return jax.lax.all_gather(summed_chunk, axis_name)

事实上,这种实现经常在 TPU 和 GPU 上使用!

psum_scatter 需要大约一半的通信量才能完成完整的 psum 的原因是 ppermute 部分中说明的。

另一种直观的理解是,我们可以使用 psum_scatter 来实现分布式矩阵乘法,其中输入和输出都沿同一轴分片。在机器学习中,psum_scatter 可用于张量并行矩阵乘法或完全分片的数据并行梯度累积,如下面的示例所示。

ppermute#

jax.lax.ppermute 集合操作为函数实例提供了一种最直接的方式来相互发送数据。给定一个网格轴和一个表示沿该网格轴的索引的 (source_index, destination_index) 对的列表,ppermute 将其参数值从每个源函数实例发送到每个目标。

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f7(x_block):
  sz = jax.lax.psum(1, 'i')
  print('BEFORE:\n', x_block)
  y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)])
  print('AFTER:\n', y_block)
  return y_block

y = f7(jnp.arange(8))
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[0 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[2 3]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 5]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[6 7]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[6 7]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[0 1]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[2 3]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[4 5]
FINAL RESULT:
 [6 7 0 1 2 3 4 5]

在这种情况下,只有两个函数实例,每个实例的 y_block 值是另一个实例的 x_block 值。

源索引和目标索引不能重复。如果某个索引没有作为目标出现,则对应函数实例的结果值是一个零数组。

collective_ref 引用实现可能如下所示:

def ppermute_ref(i, x_blocks, perm):
  results = [jnp.zeros_like(x_blocks[0])] * len(x_blocks)
  for src, dst in perm:
    results[dst] = x_blocks[src]
  return results

可以使用 ppermute 高效地实现其他集合操作,其中每个函数仅将数据传递给其相邻函数,从而实现总通信量。例如,我们可以使用一系列 ppermute 和局部加法来实现 psum_scatter,如下所示:

Illustration of a psum_scatter implementation.

或者,用一个数值示例:

Illustration of a psum_scatter implementation.

直观地说,在每次迭代中,每个函数实例都会“向上”发送它在前一次迭代中收到的值,并减少(添加)它在这次迭代中收到的值。在代码中,它可能如下所示:

def psum_scatter(x, axis_name, *, tiled=False):
  size = jax.lax.psum(1, axis_name)
  idx = jax.lax.axis_index(axis_name)  # function instance index along axis_name
  if tiled:
    x = x.reshape(size, -1, *x.shape[1:])  # split leading axis
  shift = partial(jax.lax.ppermute, axis_name=axis_name,
                  perm=[(i, (i - 1) % size) for i in range(size)])
  for i in range(1, size):
    update = shift(x[(idx + i) % size])
    x = x.at[(idx + i + 1) % size].add(update)
  return x[idx]
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f8(x_block):
  print('BEFORE:\n', x_block)
  y_block = psum_scatter(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f8(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]

FINAL RESULT:
 [22 20 12 17]

在 TPU 上,此算法具有更高维度的变体,以利用多个双向物理网格轴。

请注意,psum_scatterall_gather 的转置。实际上,使用 ppermute 实现 all_gather 的一种方法看起来像是上述过程的反向操作:

Illustration of an all_gather implementation.

在深度学习中,我们可能会在实现 SPMD 流水线并行时使用 ppermute,我们将网络沿其深度划分为多个阶段,并并行评估各阶段的应用。或者,我们可能会在并行化卷积层的评估时使用 ppermute,我们沿空间轴进行分片,因此设备必须相互传递“光晕”。或者,它可以在幕后用于张量并行矩阵乘法。

all_to_all#

最后一个集合操作是 all_to_all,它本质上是沿一个位置轴和一个跨设备轴操作的块矩阵转置。

Illustration of an all_to_all computation.
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f9(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0,
                               tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f9(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 5 5 9]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[1 9 3 7]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 2 5 1]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[1 6 8 2]

FINAL RESULT:
 [3 5 5 9 1 9 3 7 4 2 5 1 1 6 8 2]

split_axis 参数指示哪个位置轴应分片并在网格轴上进行分区。concat_axis 参数指示应沿哪个轴连接或堆叠通信的结果。

tiled=False (默认值)时,split_axis 轴的大小必须等于名为 axis_name 的网格轴的大小,并且在该大小的新轴在位置 concat_axis 上为堆叠结果创建。当 tiled=True 时,split_axis 轴的大小只需能被网格轴的大小均匀整除,并且结果将沿现有轴 concat_axis 连接。

split_axis=0concat_axis=0 时的 collective_ref 引用语义可能如下所示:

def all_to_all_ref(_, x_blocks, *, tiled=False):
  axis_size = len(x_blocks)
  if tiled:
    splits = [jnp.array_split(x, axis_size) for x in x_blocks]
    return [jnp.concatenate(s) for s in zip(*splits)]
  else:
    splits = [list(x) for x in x_blocks]
    return [jnp.stack(s) for s in zip(*splits)]

在深度学习中,我们可能会在混合专家路由中使用 all_to_all,我们首先根据示例应转到的专家对本地批次示例进行排序,然后应用 all_to_all 将示例重新分发给专家。

示例#

我们如何在实践中使用 shard_map 和集合通信?以下示例虽然简单,但也提供了一些思路。

矩阵乘法#

并行化矩阵乘法是扩展深度学习模型的关键,无论是训练还是推理。当 jax.jit 自动并行化矩阵乘法时,它可以根据矩阵大小、硬件细节和其他因素使用几种不同的策略。我们如何使用 shard_map 更明确地编写其中的一些并行化例程?我们如何优化它们以获得更好的计算/通信重叠,从而提高 FLOP 利用率?

import jax
import jax.numpy as jnp

from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh = Mesh(jax.devices()[:4], ('i',))

def device_put(x, pspec):
  return jax.device_put(x, NamedSharding(mesh, pspec))

示例 1:单侧 all-gather#

考虑执行矩阵乘法,其中我们在其前导(非收缩)维度上对左侧参数(可以认为是参数)进行分片

lhs_spec = P('i', None)
lhs = device_put(jax.random.normal(jax.random.key(0), (8, 8)), lhs_spec)

我们在其收缩维度上对右侧参数(可以认为是激活)进行分片,并对输出进行类似的分片

rhs_spec = P('i', None)
rhs = device_put(jax.random.normal(jax.random.key(1), (8, 4)), rhs_spec)

要执行此矩阵乘法,我们可以首先全收集右侧,然后针对分片的左侧执行局部矩阵乘法

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather(lhs_block, rhs_block):
  rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)
  return lhs_block @ rhs
out = matmul_allgather(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

这很好,但我们在这里没有得到任何计算/通信重叠:在我们可以开始 matmul 之前,我们需要 all_gather 完成。这是使用相同代码但在更大的示例形状(lhs(8192, 8192)rhs(8192, 1024))上的配置文件

Profile of an all-gather matmul without overlap.

如果不是调用 all_gather,而是基本上以内联方式使用我们上面 all_gatherppermute 实现,然后将收集置换的步骤与局部矩阵乘法交错,就可以获得计算/通信重叠

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather_overlapped(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift = partial(jax.lax.ppermute, axis_name='i',
                  perm=[(i, (i + 1) % size) for i in range(size)])

  B = lhs_block.shape[1] // size
  lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)

  out_block = lhs_blocks(idx) @ rhs_block
  for i in range(1, size):
    rhs_block = shift(rhs_block)
    out_block += lhs_blocks((idx - i) % size) @ rhs_block
  return out_block
out = matmul_allgather_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

此实现允许通信和计算之间的重叠,并避免在每个设备上收集大型中间体。但在 TPU 上,它通过仅沿环的一个方向置换来使用一半的互连带宽。要双向置换,我们只需将块分成两半,并沿每个方向发送一半

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift_up = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i + 1) % size) for i in range(size)])
  shift_dn = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i - 1) % size) for i in range(size)])

  B = lhs_block.shape[1] // size // 2  # half-size blocks
  lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 1)

  rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0)
  out_block  = lhs_blocks(idx, 0) @ rhs_block_lo
  out_block += lhs_blocks(idx, 1) @ rhs_block_hi
  for i in range(1, size):
    rhs_block_lo = shift_up(rhs_block_lo)
    rhs_block_hi = shift_dn(rhs_block_hi)
    out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo
    out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi
  return out_block
out = matmul_allgather_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
Profile of an all-gather matmul with overlap.

实际上,为了减少编译时间,我们可能会将其滚入 jax.lax.fori_loop。我们可能还会涉及其他并行轴。

示例 2:psum_scatter 结果#

我们可能开始的另一种分片方式是沿着它们的收缩维度对 lhsrhs 进行分片,输出再次像 rhs 一样分片

lhs_spec = P(None, 'i')
lhs = device_put(lhs, lhs_spec)

rhs_spec = P('i', None)
rhs = device_put(rhs, rhs_spec)

在这里,我们可以使用 reduce_scatter 来执行分片上的收缩求和

@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_psumscatter(lhs_block, rhs_block):
  out_summand = lhs_block @ rhs_block
  return jax.lax.psum_scatter(out_summand, 'i', tiled=True)

out = matmul_psumscatter(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

但是,散射通信必须等待整个局部矩阵乘法完成后才能开始。为了获得通信/计算重叠,我们可以以内联方式使用 ppermutepsum_scatter 实现,然后将通信步骤与局部矩阵乘法交错

@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_psumscatter_overlapped(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift = partial(jax.lax.ppermute, axis_name='i',
                  perm=[(i, (i - 1) % size) for i in range(size)])
  lhs_block = lhs_block.reshape(size, -1, lhs_block.shape[1])  # split 1st axis

  out_summand = lhs_block[(idx + 1) % size] @ rhs_block
  for i in range(1, size):
    out_summand = shift(out_summand)
    out_summand += lhs_block[(idx + i + 1) % size] @ rhs_block
  return out_summand
out = matmul_psumscatter_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

与前面的示例一样,为了充分利用 TPU 上的互连,我们将运行双向版本

@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift_up = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i + 1) % size) for i in range(size)])
  shift_dn = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i - 1) % size) for i in range(size)])

  B = lhs_block.shape[0] // size // 2  # half-size blocks
  lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 0)

  out_summand_lo = lhs_blocks((idx - 1) % size, 0) @ rhs_block
  out_summand_hi = lhs_blocks((idx + 1) % size, 1) @ rhs_block
  for i in range(1, size):
    out_summand_lo = shift_up(out_summand_lo)
    out_summand_hi = shift_dn(out_summand_hi)
    out_summand_lo += lhs_blocks((idx - i - 1) % size, 0) @ rhs_block
    out_summand_hi += lhs_blocks((idx + i + 1) % size, 1) @ rhs_block
  return jnp.concatenate([out_summand_lo, out_summand_hi])
out = matmul_psumscatter_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

神经网络#

我们可以使用 shard_map 来并行化神经网络中的计算,既可以单独使用,也可以与 jax.jit 中的自动分区结合使用。本节包含一些基于此玩具神经网络和随机数据的示例

import jax
import jax.numpy as jnp

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jax.nn.relu(outputs)
  return outputs

def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
def init_layer(key, n_in, n_out):
  k1, k2 = jax.random.split(key)
  W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
  b = jax.random.normal(k2, (n_out,))
  return W, b

def init(key, layer_sizes, batch_size):
  key, *keys = jax.random.split(key, len(layer_sizes))
  params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))

  key, *keys = jax.random.split(key, 3)
  inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
  targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))

  return params, (inputs, targets)
layer_sizes = [784, 128, 128, 128, 128, 128, 8]
batch_size = 32

params, batch = init(jax.random.key(0), layer_sizes, batch_size)

将这些示例与“分布式数组和自动分区”文档中的纯自动分区示例进行比较。虽然在那些自动分区示例中,我们不需要编辑模型函数来使用不同的并行化策略,但使用 shard_map,我们通常需要这样做。

8 路批数据并行#

最简单的多设备并行策略是将输入和目标批次分片到多个设备上,在这些设备上复制参数,并并行地将模型应用于这些数据分片。为了评估总损失,设备只需要在最后使用标量大小的全归约求和进行通信。(为了评估损失的梯度,设备必须在反向传递中执行参数梯度的全归约求和。)

from functools import partial

from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = jax.make_mesh((8,), ('batch',))

# replicate initial params on all devices, shard data batch over devices
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P()))

# adapt the loss function to sum the losses across devices
def loss_dp(params, batch):
  @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())
  def loss_spmd(local_batch):
    inputs, targets = local_batch
    predictions = predict(params, inputs)  # use reference 'predict`
    local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
    return jax.lax.pmean(local_loss, 'batch')
  return loss_spmd(batch)

我们可以检查损失及其梯度是否与参考(基本)模型匹配

print(jax.jit(loss)(params, batch))
print(jax.jit(loss_dp)(params, batch))
22.779888
22.779888
def allclose(a, b):
  return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))

print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_dp))(params, batch)))
True

我们可以打印编译器 IR 以检查梯度计算,并验证集体全归约求和操作发生在我们期望的位置:在正向传递的末尾计算损失值,以及在反向传递中计算总参数梯度。

8 路完全分片数据并行 (FSDP)#

另一种策略是在设备上额外分片参数,当 jnp.dot 或偏差相加需要完整值时,全收集每个参数。由于我们一次只有一个完整参数在本地设备内存中,而不是像前面的 DP 示例中那样在所有设备内存中保留所有参数,因此我们释放了大量的内存,可以将其用于更大的模型或更大的批次大小。而且,由于 XLA 会重叠计算和设备间通信,因此挂钟时间不会受到影响。

因此,现在我们需要在两个地方进行集合:模型预测函数 predict 需要在参数使用之前全收集它们,并且与 DP 情况一样,损失函数需要对局部损失求和以计算总损失。

我们还需要另一个要素:我们不希望存储正向传递中完全收集的参数以用于反向传递。相反,我们希望在反向传递中再次收集它们。我们可以使用 jax.remat自定义策略(或 custom_vjp)来表达这一点,尽管 XLA 通常会自动执行该重新实现。

这种通用的FSDP 方法类似于 权重更新分片 (WUS)ZeRO-3

# shard data batch *and params* over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P('batch')))

# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them)
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp(params_frag, inputs):
  for W_frag, b_frag in params_frag:
    W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
    b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
    outputs = jnp.dot(inputs, W) + b
    inputs = jax.nn.relu(outputs)
  return outputs

def loss_fsdp(params, batch):
  @partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())
  def loss_spmd(local_params, local_batch):
    inputs, targets = local_batch
    predictions = predict_fsdp(local_params, inputs)
    local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
    return jax.lax.pmean(local_loss, 'batch')
  return loss_spmd(params, batch)

同样,我们可以检查损失及其梯度是否与参考模型匹配

print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp)(params, batch))

print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_fsdp))(params, batch)))
22.779888
22.779888
True

8 路张量并行 (TP)#

通常我们不单独使用张量模型并行,但单独查看它对于并行矩阵乘法来说是一个很好的热身。这也是在更大的基于 jit 的计算中调用 shard_map 在库函数中使用的很好的例子。

并行化思想是,我们将数据/激活保持在其特征轴(而不是其批次轴)上的分片,并且我们类似地在其输入特征轴上分片权重矩阵(并在其特征轴上分片偏差)。然后,为了执行并行矩阵乘法,我们将执行局部矩阵乘法,然后执行 psum_scatter 以对局部结果求和并有效地散布结果的分片。

mesh = jax.make_mesh((8,), ('feats',))

batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))
params = jax.device_put(params, NamedSharding(mesh, P('feats')))

def predict_tp(params, inputs):
  for W, b in params:
    outputs = gemm_tp(inputs, W, b)
    inputs = jax.nn.relu(outputs)
  return outputs

@partial(shard_map, mesh=mesh,
         in_specs=(P(None, 'feats'), P('feats', None), P('feats')),
         out_specs=P(None, 'feats'))
def gemm_tp(inputs, W, b):
  block_result = jnp.dot(inputs, W)
  return jax.lax.psum_scatter(block_result, 'feats',
                              scatter_dimension=1, tiled=True) + b

def loss_tp(params, batch):
  inputs, targets = batch
  predictions = predict_tp(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1))  # NOTE psum!

FSDP + TP,顶层使用 shard_map#

我们可以将这些策略组合在一起,使用多个并行轴。

mesh = jax.make_mesh((4, 2), ('batch', 'feats'))

batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))
params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))

# mostly same as previous predict_fsdp definition, except we call gemm_tp
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp_tp(params_frag, inputs):
  for W_frag, b_frag in params_frag:
    W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
    b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
    block_result = jnp.dot(inputs, W)
    outputs = jax.lax.psum_scatter(block_result, 'feats',
                                   scatter_dimension=1, tiled=True) + b
    inputs = jax.nn.relu(outputs)
  return outputs

@partial(shard_map, mesh=mesh,
         in_specs=(P(('feats', 'batch')), P('batch', 'feats')),
         out_specs=P())
def loss_fsdp_tp(local_params, local_batch):
  inputs, targets = local_batch
  predictions = predict_fsdp_tp(local_params, inputs)
  sq_err = jax.lax.psum(jnp.sum((predictions - targets)**2, axis=-1), 'feats')
  return jax.lax.pmean(jnp.mean(sq_err), 'batch')

请注意,我们必须进行两次集体归约:一次针对 'feats',一次针对 'batch'。在纯 TP 示例中,我们没有显式地编写 'feats' 归约,因为我们只在 gemm_tp 中使用了 shard_map;在调用方 loss_tp 中,编译器根据 predict_tp 返回的分片结果,自动将我们对 jnp.sum 的使用转换为根据需要执行 psum

print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp_tp)(params_, batch_))

print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_fsdp_tp))(params, batch)))
22.779886
22.779886
True

SPMD 流水线并行 (PP)#

通过流水线并行,我们的目标是并行化网络中不同深度的层的评估。例如,一个设备可能会计算第一层的应用,而另一个设备可能会计算第二层的应用;当它们完成时,第一个设备将其结果传递给第二个设备,而第二个设备将其结果传递给负责第三层的设备,并且该过程重复进行。通常,流水线阶段的数量可能与层的数量不同,因为每个阶段可能负责多层。

借助 SPMD 流水线,我们可以利用以下事实:网络中的大多数层都应用计算,只是参数值不同。特别是,我们可以将除第一层和最后一层的参数之外的所有参数堆叠在一起,然后使用 shard_map 来映射这些层参数的块,其中每个参数块都对应一个流水线阶段。然后,我们使用 jax.lax.ppermute 集合体来向下移动并行流水线中的数据。

这种特定的流水线策略本质上是GPipe 策略。有几种变体,以及完全不同的策略,哪种策略合适取决于阶段之间的网络速度和批次大小。但对于本教程,我们将只关注一种策略。

首先,我们选择一些流水线参数

L = len(params) - 2        # num layers, excluding first and last
N = batch_size             # batch size
F = params[0][0].shape[1]  # num features

# choose some pipeline parameters
S = 2      # number of stages
B = 8      # size of each microbatch
assert L % S == 0, "S (number of stages) must divide L (number of inner layers)"

# compute some useful quantities
M, ragged = divmod(N, B)  # M is number of microbatches
assert not ragged, "B (size of each microbatch) must divide total batch size"
K, ragged = divmod(M, S)  # K is microbatches per stage
assert not ragged, "S (number of stages) must divide number of microbatches"
print(f'{S} stages, {L // S} layer(s) per stage, {L} pipelined layers total')
print(f'{B} examples per microbatch, {M} microbatches total')
2 stages, 2 layer(s) per stage, 4 pipelined layers total
8 examples per microbatch, 4 microbatches total
mesh = Mesh(jax.devices()[:S], ('stages',))

def predict_pp(params, inputs):
  (W_first, b_first), inner_params, (W_last, b_last) = params
  inputs = jax.nn.relu(jnp.dot(inputs, W_first) + b_first)
  inputs = spmd_pipeline(lambda Wb, x: jax.nn.relu(x @ Wb[0] + Wb[1]),
                        inner_params, inputs)
  outputs = jnp.dot(inputs, W_last) + b_last
  return outputs

@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),
         out_specs=P())
def loss_pp(params, batch):
  inputs, targets = batch
  predictions = predict_pp(params, inputs.reshape(K, B, -1)).reshape(K * B, -1)
  local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
  return jax.lax.pmean(local_loss, 'stages')
def spmd_pipeline(fn, stage_params, inputs):
  stage = jax.lax.axis_index('stages')
  outputs = jnp.zeros_like(inputs) * jnp.nan
  state = jnp.zeros((L // S, B, F)) * jnp.nan
  for i in range(M+L-1):
    state = state.at[0].set(jnp.where(stage == 0, inputs[i % K], state[0]))
    state = jax.vmap(fn)(stage_params, state)
    outputs = outputs.at[(i-L+1) % K].set(jnp.where(stage == S-1, state[-1], outputs[(i-L+1) % K]))
    state, inputs, outputs = shift(i, state, inputs, outputs)
  outputs = jax.lax.ppermute(outputs, 'stages', [(i, (i+1) % S) for i in range(S)])
  return outputs

def shift(i, state, inputs, outputs):
  sh = lambda x, d: jax.lax.ppermute(x, 'stages', [(i, (i+d) % S) for i in range(S)])
  state = jnp.roll(state, +1, axis=0).at[0].set(sh(state[-1], +1))
  if (i % K) == (-1 % K):
    inputs = sh(inputs, +1)
  if ((i-L+1) % K) == (-1 % K):
    outputs = sh(outputs, +1)
  return state, inputs, outputs
first_params, *inner_params, last_params = params
Ws, bs = zip(*inner_params)
params_stacked = jnp.stack(Ws), jnp.stack(bs)
first_params = jax.device_put(first_params, NamedSharding(mesh, P()))
params_stacked = jax.device_put(params_stacked, NamedSharding(mesh, P('stages')))
last_params = jax.device_put(last_params, NamedSharding(mesh, P()))
params_ = first_params, params_stacked, last_params

batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages')))
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_pp)(params_, batch_))
22.779886
22.779884
_ = jax.jit(jax.grad(loss_pp))(params_, batch_)   # don't crash