使用 shard_map 手动并行化#

概述#

shard_map 是一个单程序多数据 (SPMD) 多设备并行化 API,用于将函数映射到数据分片上。映射的函数应用或实例通过显式集体通信操作彼此通信。

shard_mapjit 内置的自动编译器并行化的补充,并且可以与其组合使用。使用 jit,您可以像编写单个设备的代码一样编写代码,并且 编译器可以自动将计算划分到多个设备上,在幕后生成每个设备的代码和通信集合。使用 shard_map,您可以掌控一切,编写自己的分区代码和显式集合。或者,您可以混合使用两种方法:跨设备组进行手动控制,同时将组内设备分区留给编译器。这两种方法可以根据需要混合、匹配和组合。

如果您熟悉 pmap,您可以将 shard_map 视为其演进。它更具表现力、性能更高,并且可以与其他 JAX API 组合使用。它甚至可以按需执行,方便调试!(有关更多信息,请参阅 pmap 的详细比较。)

通过阅读本教程,您将学习如何使用 shard_map 来完全控制您的多设备代码。您将详细了解它如何与 jax.jit 的自动并行化和 jax.grad 的自动微分组合使用。我们还将提供神经网络并行化策略的一些基本示例。

我们将假设本教程在具有八个设备的环境中运行。

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices

所以,让我们看看 shard_map#

不多说,这里有一个玩具示例。

from functools import partial

import jax
import jax.numpy as jnp

from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 *  4.).reshape(16, 4)

@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
         out_specs=P('x', None))
def matmul_basic(a_block, b_block):
  # a_block: f32[2, 8]
  # b_block: f32[8, 4]
  c_partialsum = jnp.dot(a_block, b_block)
  c_block = jax.lax.psum(c_partialsum, 'y')
  # c_block: f32[2, 4]
  return c_block

c = matmul_basic(a, b)   # c: f32[8, 4]

此函数通过执行局部块矩阵乘法,然后进行集合求和操作,来并行计算矩阵乘法。我们可以检查结果是否正确。

from jax.tree_util import tree_map, tree_all

def allclose(a, b):
  return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))

allclose(c, jnp.dot(a, b))
True

结果沿着其行进行分片。

jax.debug.visualize_array_sharding(c)
            
  CPU 0,1   
            
            
  CPU 2,3   
            
            
  CPU 4,5   
            
            
  CPU 6,7   
            

从总体上看,shard_map 类似于 vmappmap,因为我们正在将函数映射到数组数据的各个部分,但请注意

  • shard_map 将输入切分成块(输出通过连接结果块形成),保持等级不变,而 vmap 会通过映射一个轴来降低等级;

  • mesh 参数允许我们控制计算和结果的精确设备放置;

  • 我们同时映射多个数据轴,并为集合设置多个轴名称(这里都是 'x''y');

  • 由于我们还没有使用 jax.jit,因此所有内容都是按需执行的,我们甚至可以 print 中间值以进行调试。

上面的代码执行与以下 jax.jit 自动并行化代码相同的计算。

from jax.sharding import NamedSharding

a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))
b = jax.device_put(b, NamedSharding(mesh, P('y', None)))

@jax.jit
def matmul_reference(a, b):
  c = jnp.dot(a, b)
  return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))

c_ref = matmul_reference(a, b)
allclose(c_ref, jnp.dot(a, b))
True

我们可以将 shard_map 视为根据其 meshin_specs 参数对输入执行 device_putwith_sharding_constraint,因此 matmul_basic 操作的块与 matmul_reference 中的相同。

print('a blocks:'); jax.debug.visualize_array_sharding(a)
print('b blocks:'); jax.debug.visualize_array_sharding(b)
print('c blocks:'); jax.debug.visualize_array_sharding(c)
a blocks:
b blocks:
c blocks:
                                                  
          CPU 0                    CPU 1          
                                                  
                                                  
          CPU 2                    CPU 3          
                                                  
                                                  
          CPU 4                    CPU 5          
                                                  
                                                  
          CPU 6                    CPU 7          
                                                  
           
           
CPU 0,2,4,6
           
           
           
           
           
CPU 1,3,5,7
           
           
           
            
  CPU 0,1   
            
            
  CPU 2,3   
            
            
  CPU 4,5   
            
            
  CPU 6,7   
            

慢点,从基础开始!#

降维映射与保留维映射#

我们可以将 vmappmap 视为沿轴展开每个数组输入(例如,将 2D 矩阵展开为其 1D 行),将主体函数应用于每个部分,并将结果堆叠在一起,至少在不涉及集合的情况下如此。

def check_vmap(f, xs):
  ans = jax.vmap(f, in_axes=(0,), out_axes=0)(xs)
  expected = jnp.stack([f(x) for x in xs])  # vmap reference semantics
  print(allclose(ans, expected))

check_vmap(lambda x: x @ x, jnp.arange(12).reshape(4, 3))
True

例如,如果 xs 的形状为 f32[8,5],那么每个 x 的形状为 f32[5],如果每个 f(x) 的形状为 f32[3,7],那么最终堆叠的结果 vmap(f)(xs) 的形状为 f32[8,3,7]。也就是说,主体函数 f 的每次应用都将 vmap(f) 的对应参数作为输入,但少一个轴。我们可以说这些是输入/输出展开/堆叠的降维映射

f 的逻辑应用次数,或 f实例,由要映射的输入轴的大小决定:例如,如果我们映射一个大小为 8 的输入轴,那么从语义上讲,我们将获得 8 个函数的逻辑应用。

相反,shard_map 没有这种降维行为。相反,我们可以将其视为沿着输入轴切片(或“拆分连接”)成块,应用主体函数,并将结果连接在一起(同样在不涉及集合的情况下)。

import numpy as np
devices = np.array(jax.devices()[:4])
mesh = Mesh(devices, ('i',))  # mesh.shape['i'] = 4

def check_shmap(f, y):
  ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y)
  expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])])
  print(allclose(ans, expected))

check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4))
True

回想一下,jnp.split 将其输入切分成具有相同等级的等大小块,因此,如果在上面的示例中,y 的形状为 f32[8,5],那么每个 y_blk 的形状为 f32[2,5],如果每个 f(y_blk) 的形状为 f32[3,7],那么最终连接的结果 shard_map(f, ...)(y) 的形状为 f32[12,7]。因此,shard_map 将其输入映射到分片或块上。我们可以说它是输入/输出拆分连接/连接的保留维映射

函数 f 的逻辑应用次数由网格大小决定,而不是由任何输入轴大小决定:例如,如果我们有一个总大小为 4 的网格(即,跨 4 个设备),那么从语义上讲,我们将获得 4 个函数的逻辑应用,对应于物理上执行它们的 4 个设备。

使用 in_specs 控制每个输入如何拆分(拆分连接)和拼接#

每个 in_specs 使用 PartitionSpec 通过名称将对应输入数组的某些轴与网格轴标识起来,表示如何将该输入拆分(或拆分连接)成应用主体函数的块。该标识确定了分片大小;当输入轴与网格轴标识时,输入将沿着该逻辑轴拆分(拆分连接)成与对应网格轴大小相同的多个部分。(如果对应网格轴大小不能被输入数组轴大小整除,则会报错。)如果输入的 pspec 没有提及网格轴名称,那么不会沿着该网格轴进行拆分。例如

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('i', 'j'))

@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
  print(x_block.shape)  # prints (3, 12)
  return x_block

x1 = jnp.arange(12 * 12).reshape(12, 12)
y = f1(x1)
(3, 12)

这里,因为输入 pspec 没有提及网格轴名称 'j',因此没有输入数组轴沿着该网格轴拆分;类似地,因为输入数组的第二个轴没有与任何网格轴标识(因此也没有沿着其拆分),所以应用 f1 将获得沿着该轴的输入的完整视图。

当输入 pspec 中没有提及网格轴时,我们始终可以将其改写为效率较低的程序,其中所有网格轴都被提及,但调用者执行 jnp.tile,例如

@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
  print(x_block.shape)
  return x_block

x = jnp.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, mesh.shape['j']))  # x_ has shape (12, 24)
y = f2(x_)  # prints (3,12), and f1(x) == f2(x_)
(3, 12)

换句话说,因为每个输入 pspec 可以将每个网格轴名称提及零次或一次,而不是必须精确提及每个名称一次,所以我们可以说,除了其输入中内置的 jnp.split 之外,shard_map 的输入中还内置了 jnp.tile,至少在逻辑上是如此(尽管拼接可能不需要物理执行,具体取决于参数的物理分片布局)。要使用的拼接不是唯一的;我们也可以沿着第一个轴拼接,并使用 pspec P(('j', 'i'), None)

输入可能需要进行物理数据移动,因为每个设备都需要拥有适当数据的副本。

使用 out_specs 控制每个输出如何通过连接、块转置和拼接解除拼接#

与输入端类似,每个 out_specs 通过名称将对应输出数组的某些轴与网格轴标识起来,表示应该如何将输出块(每个主体函数应用一个,或者等效地,每个物理设备一个)重新组装在一起以形成最终的输出值。例如,在上面 f1f2 示例中,out_specs 表明我们应该通过沿着两个轴将块结果连接在一起,来形成最终输出,在两种情况下都得到形状为 (12, 24) 的数组 y。(如果主体函数的输出形状(即输出块形状)的等级小于对应输出 pspec 描述的连接所需要的等级,则会报错。)

当输出 pspec 中没有提及网格轴名称时,它表示一个拼接解除拼接:当用户编写没有提及任何网格轴名称的输出 pspec 时,他们承诺输出块沿着该网格轴是相等的,因此沿着该轴仅使用一个块(而不是沿着该网格轴将所有块连接在一起)。例如,使用与上面相同的网格

x = jnp.array([[3.]])

z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()
print(z)  # prints the same as jnp.tile(x, (4, 2))

z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()
print(z)  # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))

z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()
print(z)  # prints the same as jnp.tile(x, (1, 1)), or just x
[[3. 3.]
 [3. 3.]
 [3. 3.]
 [3. 3.]]
[[3.]
 [3.]
 [3.]
 [3.]]
[[3.]]

将数组值传递给函数体的闭包等效于将其作为增量传递,并使用相应的输入 pspec 为 P(None, None)。 另一个例子,更紧密地遵循上面其他例子

@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
  return jax.lax.psum(x_block, 'j')

x = jnp.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape)
(12, 6)

结果的第二个轴大小为 6,是输入第二个轴大小的一半。 在这种情况下,由于集体 psum,未在输出 pspec 中提及网格轴名称 'j' 所表达的取消平铺是安全的,这确保了每个输出块沿相应的网格轴相等。 以下还有两个例子,我们改变了输出 pspec 中提到的网格轴

@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
  return jax.lax.psum(x_block, 'i')

x = jnp.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape)  # (3,12)


@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
  return jax.lax.psum(x_block, ('i', 'j'))

y5 = f5(x)
print(y5.shape)  # (3,6)
(3, 12)
(3, 6)

在物理方面,在输出 pspec 中不提及网格轴名称会从输出设备缓冲区中组装一个 Array,该缓冲区沿该网格轴具有复制布局。

没有运行时检查以确保输出块在取消平铺的网格轴上实际上相等,或者等效地确保相应的物理缓冲区具有相等的值,因此可以解释为单个逻辑数组的复制布局。 但是,我们可以提供一种静态检查机制,该机制会在所有可能不正确的程序上引发错误。

由于 out_specs 可以零次或一次提及网格轴名称,并且它们可以以任何顺序提及,因此除了 jnp.concatenate 内置于其输出之外,我们还可以说 shard_map 的输出中还包含取消平铺块转置

无论输出 pspec 如何,物理数据移动在输出上都是不可能的。 相反,out_specs 只是编码如何将块输出组装成 Array,或者物理上如何解释跨设备的缓冲区作为单个逻辑 Array 的物理布局。

API 规范#

from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]

def shard_map(
    f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
    auto: collections.abc.Set[AxisName] = frozenset([]),
    check_rep: bool = True,
) -> Callable:
  ...

其中

  • f 函数体中的 psum 这样的通信集体可以提及 mesh 的轴名称;

  • mesh 编码排列成数组的设备以及关联的轴名称,就像它对 sharding.NamedSharding 一样;

  • in_specsout_specsPartitionSpec,它们可以仿射地提及来自 mesh 的轴名称,以分别表达输入和输出的切片/取消连接和连接,未提及的名称分别对应于复制和取消平铺(断言复制,因此给我一个副本);

  • auto 是一个可选的轴名称集,对应于 mesh 的名称子集,这些名称将在函数体中自动处理,就像在调用者中一样,而不是手动处理;

  • check_rep 是一个可选的布尔值,指示是否在静态上检查 out_specs 中的任何复制错误,以及是否启用相关的自动微分优化(参见 JEP)。

传递给 f 的参数的形状与传递给 shard_map-of-f 的参数的形状具有相同的秩,并且传递给 f 的参数的形状是从相应的参数的形状 shape 计算得出的,传递给 shard_map-of-f 以及相应的 PartitionSpec spec,大致为 tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))

集体教程#

一个 shard_map 不必是纯粹的映射:函数应用可以使用在 mesh 参数中定义的轴名称,通过集体相互通信。

回想一下,shard_map 将函数映射到输入数据的碎片或块上,因此

mesh = Mesh(jax.devices(), ('i',))
x = jnp.arange(16.)
f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))
y = f_shmapped(x)

计算相同的值,评估 f 对相同参数值的应用,就像这个参考函数一样

def f_shmapped_ref(x):
  x_blocks = jnp.array_split(x, mesh.shape['i'])
  y_blocks = [f(x_blk) for x_blk in x_blocks]
  return jnp.concatenate(y_blocks)

我们将这些 f 对不同参数碎片的应用称为函数实例。 每个函数实例都在不同的设备(或设备子集)上执行。

f 中没有通信集体时,这些参考语义有效。 但是,如果我们希望函数实例进行通信,对应于跨设备通信呢? 也就是说,当 f 包含一个集体时,参考语义是什么? 假设 f 只有一个集体,形式为

def f(x_blk):
  z_blk = f_part1(x_blk)
  u_blk = collective(z_blk, axis_name)
  v_blk = f_part2(x_blk, z_blk, u_blk)
  return v_blk

我们假设只有一个网格轴我们要映射,并且 axis_name 是它的对应名称。 那么参考语义看起来更像

def f_shmapped_ref(x):
  x_blocks = jnp.array_split(x, mesh.shape[0])
  z_blocks = [f_part1(x_blk) for x_blk in x_blocks]
  u_blocks = [collective_ref(i, z_blocks) for i in range(len(z_blocks))]
  v_blocks = [f_part2(x_blk, z_blk, u_blk) for x_blk, z_blk, u_blk
              in zip(x_blocks, z_blocks, u_blocks)]
  return jnp.concatenate(v_blocks)

请注意,collective_ref 可能取决于所有 z_blocks。 也就是说,虽然 f_part1f_part2 是独立地映射到块上的,但集体引入了一定程度的跨块依赖性。 从物理上讲,这意味着跨设备通信。 到底发生了什么通信,以及计算了什么值,取决于集体。

psum#

最简单的集体可能是 jax.lax.psum,它沿设备网格轴(或多个轴)计算全简并和。 以下是一个玩具示例

Illustration of a psum computation.
import jax
import jax.numpy as jnp
from jax import lax

from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh1d = Mesh(jax.devices()[:4], ('i',))

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))
def f1(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, 'i')
  print('AFTER:\n', y_block)
  return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f1(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22 20 12 17]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[22 20 12 17]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[22 20 12 17]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[22 20 12 17]

FINAL RESULT:
 [22 20 12 17]

打印输出显示每个函数应用都从参数值 x_block 的自身块开始。 在 psum 之后,每个函数应用都具有相同的值 y_block,通过将应用的 x_block 值加在一起计算得出。

在计算中只有一个轴名称的情况下,我们可以说 psumcollective_ref 参考实现为

def psum_ref(_, x_blocks):
  tot = sum(x_blocks)
  return [tot] * len(x_blocks)

还要注意,由于 f1 返回 y_block,这是 'i' 上的 psum 的结果,因此我们可以使用 out_specs=P(),以便调用者获得结果值的单个逻辑副本,而不是平铺结果。

当存在多个网格轴时,我们可以分别对每个轴执行 psum,或者对多个轴同时执行 psum

mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))

@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f2(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, 'i')
  print('AFTER:\n', y_block)
  return y_block

y = f2(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
 [4 5]]

On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
 [6 7]]

On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8  9]
 [12 13]]

On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
 [14 15]]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[ 8 10]
 [16 18]]

On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[12 14]
 [20 22]]

On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 10]
 [16 18]]

On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[12 14]
 [20 22]]

FINAL RESULT:
 [[ 8 10 12 14]
 [16 18 20 22]]

通过对网格轴 'i' 应用 psum,我们得到 y_block 的值,这些值在轴 ‘i' 上相等,但在轴 'j' 上不相等。 (因此我们可以使用 out_specs=P(None, 'j') 来获得沿该轴的单个逻辑结果。)

如果我们将 psum 应用于这两个轴,则 y_block 值在这两个轴上都相等

@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))
def f3(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, ('i', 'j'))
  print('AFTER:\n', y_block)
  return y_block

y = f3(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
 [4 5]]

On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
 [6 7]]

On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8  9]
 [12 13]]

On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
 [14 15]]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[20 24]
 [36 40]]

On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[20 24]
 [36 40]]

On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[20 24]
 [36 40]]

On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[20 24]
 [36 40]]

FINAL RESULT:
 [[20 24]
 [36 40]]

在机器学习中,我们经常使用 psum 来计算总损失,或者当我们在 shard_maped 函数体中具有 grad 时,计算总梯度。

在后面的内容中,我们将看到如何用其他原语实现 psum,这会对它的通信成本提供一些直觉。

all_gather#

另一个基本操作是沿一个轴收集数组碎片,以便每个函数应用都拥有该轴上数据的完整副本

Illustration of an all_gather computation.
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f4(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_gather(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 9, 5, 2])
y = f4(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 9 5 2]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[3 9 5 2]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[3 9 5 2]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[3 9 5 2]

FINAL RESULT:
 [3 9 5 2 3 9 5 2 3 9 5 2 3 9 5 2]

打印输出显示每个函数应用都从参数值 x_block 的自身块开始。 在 all_gather 之后,它们具有一个公共值,该值通过连接 x_block 的值计算得出。

(请注意,我们实际上不能在这里设置 out_specs=P()。 由于与自动微分相关的技术原因,我们认为 all_gather 的输出不能保证在各个设备之间保持不变。 如果我们希望它保证保持不变,我们可以使用 jax.lax.all_gather_invariant,或者在这种情况下,我们可以简单地避免在函数体中执行 all_gather,而是使用 out_specs=P('i') 来执行连接。)

tiled=False(默认值)时,结果会沿新的轴堆叠而不是连接

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f5(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_gather(x_block, 'i', tiled=False)
  print('AFTER:\n', y_block)
  return y_block

y = f5(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[[3]
 [9]
 [5]
 [2]]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[[3]
 [9]
 [5]
 [2]]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[[3]
 [9]
 [5]
 [2]]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[[3]
 [9]
 [5]
 [2]]

FINAL RESULT:
 [[3]
 [9]
 [5]
 [2]
 [3]
 [9]
 [5]
 [2]
 [3]
 [9]
 [5]
 [2]
 [3]
 [9]
 [5]
 [2]]

我们可以将 all_gathercollective_ref 参考语义函数编写为

def all_gather_ref(_, x_blocks, *, tiled=False):
  combine = jnp.concatenate if tiled else jnp.stack
  return [combine(x_blocks)] * len(x_blocks)

在深度学习中,我们可能会在完全分片数据并行 (FSDP) 中对参数使用 all_gather

psum_scatter#

集体操作 jax.lax.psum_scatter 稍微不太直观。它类似于 psum,除了每个函数实例只获得结果的一个分片。

Illustration of a psum_scatter computation.
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f6(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f6(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]

FINAL RESULT:
 [22 20 12 17]

如打印所示,每个结果 y_block 的大小都小于参数 x_block,这与 psum 不同。此外,与 psum 相比,这里每个 y_block 只代表了 x_block 在所有函数实例中的总和的切片。(尽管每个函数实例只获得了总和的一个分片,但最终输出 ypsum 示例中的相同,因为这里我们使用 out_specs=P('i') 来连接每个函数实例的输出。)

在计算值的方面,collective_ref 引用实现可能看起来像

def psum_scatter_ref(i, x_blocks, *, tiled=False):
  axis_size = len(x_blocks)
  tot = sum(x_blocks)
  if tiled:
    tot = tot.reshape(axis_size, -1, *tot.shape[1:])  # split leading axis
  return [tot[i] for i in range(tot.shape[0])]

在语义参考实现中没有捕获到,但 psum_scatter 很有用,因为与完整的 psum 相比,这些结果可以更有效地计算,通信量更少。实际上,我们可以将 psum_scatter 视为“psum 的前半部分,在 all_gather 之前”。也就是说,实现 psum 的一种方法是

def psum(x, axis_name):
  summed_chunk = jax.lax.psum_scatter(x, axis_name)
  return jax.lax.all_gather(summed_chunk, axis_name)

实际上,这种实现通常在 TPU 和 GPU 上使用!

ppermute 部分说明了为什么 psum_scatter 可能需要大约一半的完整 psum 的通信量。

另一个直觉是我们可以使用 psum_scatter 来实现一个分布式矩阵乘法,该矩阵乘法的输入和输出在同一个轴上被分片。在机器学习中,psum_scatter 可用于张量并行矩阵乘法或全分片数据并行梯度累积,如以下示例所示。

ppermute#

集体操作 jax.lax.ppermute 为函数实例提供了一种最直接的方式来互相发送数据。给定一个网格轴和一个 (source_index, destination_index) 对列表,表示该网格轴上的索引,ppermute 将其参数值从每个源函数实例发送到每个目标函数实例。

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f7(x_block):
  sz = jax.lax.psum(1, 'i')
  print('BEFORE:\n', x_block)
  y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)])
  print('AFTER:\n', y_block)
  return y_block

y = f7(jnp.arange(8))
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[0 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[2 3]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 5]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[6 7]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[6 7]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[0 1]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[2 3]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[4 5]

FINAL RESULT:
 [6 7 0 1 2 3 4 5]

在这种情况下,只有两个函数实例,每个实例的 y_block 值都是另一个实例的 x_block 值。

源索引和目标索引不能重复。如果一个索引没有作为目标出现,则对应函数实例结果的值为零数组。

一个 collective_ref 参考实现可能看起来像

def ppermute_ref(i, x_blocks, perm):
  results = [jnp.zeros_like(x_blocks[0])] * len(x_blocks)
  for src, dst in perm:
    results[dst] = x_blocks[src]
  return results

其他集体操作可以用 ppermute 来有效地实现(在总通信量方面),每个函数实例只将数据传递给其邻居。例如,我们可以使用一系列 ppermute 和局部加法以这种方式实现 psum_scatter

Illustration of a psum_scatter implementation.

或者,使用数值示例

Illustration of a psum_scatter implementation.

直观地说,在每次迭代中,每个函数实例都会将它在上一次迭代中收到的值“向上”发送,并减少(加)它在这次迭代中收到的值。在代码中,它可能看起来像这样

def psum_scatter(x, axis_name, *, tiled=False):
  size = jax.lax.psum(1, axis_name)
  idx = jax.lax.axis_index(axis_name)  # function instance index along axis_name
  if tiled:
    x = x.reshape(size, -1, *x.shape[1:])  # split leading axis
  shift = partial(jax.lax.ppermute, axis_name=axis_name,
                  perm=[(i, (i - 1) % size) for i in range(size)])
  for i in range(1, size):
    update = shift(x[(idx + i) % size])
    x = x.at[(idx + i + 1) % size].add(update)
  return x[idx]
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f8(x_block):
  print('BEFORE:\n', x_block)
  y_block = psum_scatter(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f8(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]

FINAL RESULT:
 [22 20 12 17]

在 TPU 上,有该算法的高维变体,以利用多个双向物理网格轴。

请注意,psum_scatterall_gather 的转置。实际上,用 ppermute 实现 all_gather 的一种方法看起来像上述过程的逆过程

Illustration of an all_gather implementation.

在深度学习中,我们在实现 SPMD 管道并行时可能使用 ppermute,我们沿着网络的深度将其划分为多个阶段,并并行评估阶段的应用。或者,我们在并行化卷积层的评估时可能使用 ppermute,我们沿着空间轴进行分片,因此设备必须互相通信“光晕”。或者它可能用于张量并行矩阵乘法中。

all_to_all#

最后一个集体操作是 all_to_all,它本质上是在一个位置轴和一个跨设备轴上进行的块矩阵转置。

Illustration of an all_to_all computation.
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f9(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0,
                               tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f9(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 5 5 9]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[1 9 3 7]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 2 5 1]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[1 6 8 2]

FINAL RESULT:
 [3 5 5 9 1 9 3 7 4 2 5 1 1 6 8 2]

参数 split_axis 指示哪个位置轴应该被分片并在网格轴上进行分区。参数 concat_axis 指示应该沿着哪个轴连接或堆叠通信结果。

tiled=False(默认情况下)时,split_axis 轴大小必须等于名为 axis_name 的网格轴的大小,并且在位置 concat_axis 处为堆叠结果创建一个大小相同的新的轴。当 tiled=True 时,split_axis 轴大小只需要能被网格轴的大小整除,结果就会沿着现有的轴 concat_axis 连接起来。

split_axis=0concat_axis=0 时,collective_ref 参考语义可能看起来像

def all_to_all_ref(_, x_blocks, *, tiled=False):
  axis_size = len(x_blocks)
  if tiled:
    splits = [jnp.array_split(x, axis_size) for x in x_blocks]
    return [jnp.concatenate(s) for s in zip(*splits)]
  else:
    splits = [list(x) for x in x_blocks]
    return [jnp.stack(s) for s in zip(*splits)]

在深度学习中,我们在专家混合路由中可能使用 all_to_all,我们首先根据每个示例应该去哪个专家对本地批次示例进行排序,然后应用一个 all_to_all 来将示例重新分配给专家。

玩具示例#

我们如何在实践中使用 shard_map 和集体通信?这些示例虽然简单,但给了一些想法。

矩阵乘法#

并行化矩阵乘法是扩展深度学习模型的核心,无论是用于训练还是推理。当 jax.jit 自动并行化矩阵乘法时,它可以使用多种不同的策略,具体取决于矩阵大小、硬件细节和其他因素。我们如何使用 shard_map 更明确地编写一些并行化例程?我们如何优化它们以获得更好的计算/通信重叠,从而提高 FLOP 利用率?

import jax
import jax.numpy as jnp

from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh = Mesh(jax.devices()[:4], ('i',))

def device_put(x, pspec):
  return jax.device_put(x, NamedSharding(mesh, pspec))

示例 1:在一侧进行 all-gather#

考虑执行一个矩阵乘法,我们沿着左侧参数(可以认为是参数)的开头(非收缩)维度进行分片。

lhs_spec = P('i', None)
lhs = device_put(jax.random.normal(jax.random.key(0), (8, 8)), lhs_spec)

我们沿着右侧参数(可以认为是激活)的收缩维度进行分片,输出也采用类似的分片方式。

rhs_spec = P('i', None)
rhs = device_put(jax.random.normal(jax.random.key(1), (8, 4)), rhs_spec)

为了执行这个矩阵乘法,我们可以先对右侧进行全收集,然后对分片的左侧执行局部矩阵乘法。

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather(lhs_block, rhs_block):
  rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)
  return lhs_block @ rhs
out = matmul_allgather(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

这很棒,但我们在这里没有得到任何计算/通信重叠:在我们可以开始矩阵乘法之前,我们需要 all_gather 完成。这是一个使用相同代码但在更大的示例形状上进行的分析 ((8192, 8192) 用于 lhs 以及 (8192, 1024) 用于 rhs)

Profile of an all-gather matmul without overlap.

如果我们不调用 all_gather,而是基本内联我们上面用 ppermute 实现的 all_gather,然后交错收集置换步骤与局部矩阵乘法,我们就可以获得计算/通信重叠。

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather_overlapped(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift = partial(jax.lax.ppermute, axis_name='i',
                  perm=[(i, (i + 1) % size) for i in range(size)])

  B = lhs_block.shape[1] // size
  lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)

  out_block = lhs_blocks(idx) @ rhs_block
  for i in range(1, size):
    rhs_block = shift(rhs_block)
    out_block += lhs_blocks((idx - i) % size) @ rhs_block
  return out_block
out = matmul_allgather_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

这种实现允许通信与计算之间的重叠,并且还避免将一个大的中间结果收集到每个设备上。但在 TPU 上,它只使用一半的互连带宽,因为只沿着环在单方向进行置换。要双向置换,我们只需要将块分成两半,并将每一半分别发送到两个方向。

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift_up = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i + 1) % size) for i in range(size)])
  shift_dn = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i - 1) % size) for i in range(size)])

  B = lhs_block.shape[1] // size // 2  # half-size blocks
  lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 1)

  rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0)
  out_block  = lhs_blocks(idx, 0) @ rhs_block_lo
  out_block += lhs_blocks(idx, 1) @ rhs_block_hi
  for i in range(1, size):
    rhs_block_lo = shift_up(rhs_block_lo)
    rhs_block_hi = shift_dn(rhs_block_hi)
    out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo
    out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi
  return out_block
out = matmul_allgather_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
Profile of an all-gather matmul with overlap.

实际上,为了减少编译时间,我们可能会将其整合到 jax.lax.fori_loop 中。我们可能还会涉及额外的并行轴。

示例 2:对结果进行 psum_scatter#

我们可能从开始的另一种分片方式是:lhsrhs 都沿着它们的收缩维度进行分片,输出像 rhs 一样再次进行分片。

lhs_spec = P(None, 'i')
lhs = device_put(lhs, lhs_spec)

rhs_spec = P('i', None)
rhs = device_put(rhs, rhs_spec)

在这里,我们可以使用 reduce_scatter 对分片进行收缩求和。

@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_psumscatter(lhs_block, rhs_block):
  out_summand = lhs_block @ rhs_block
  return jax.lax.psum_scatter(out_summand, 'i', tiled=True)

out = matmul_psumscatter(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

但散射通信必须等到整个局部矩阵乘法完成后才能开始。为了获得通信/计算重叠,我们可以内联用 ppermute 实现的 psum_scatter,然后交错通信步骤与局部矩阵乘法。

@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_psumscatter_overlapped(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift = partial(jax.lax.ppermute, axis_name='i',
                  perm=[(i, (i - 1) % size) for i in range(size)])
  lhs_block = lhs_block.reshape(size, -1, lhs_block.shape[1])  # split 1st axis

  out_summand = lhs_block[(idx + 1) % size] @ rhs_block
  for i in range(1, size):
    out_summand = shift(out_summand)
    out_summand += lhs_block[(idx + i + 1) % size] @ rhs_block
  return out_summand
out = matmul_psumscatter_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

与前面的示例一样,为了充分利用 TPU 上的互连,我们会运行一个双向版本。

@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift_up = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i + 1) % size) for i in range(size)])
  shift_dn = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i - 1) % size) for i in range(size)])

  B = lhs_block.shape[0] // size // 2  # half-size blocks
  lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 0)

  out_summand_lo = lhs_blocks((idx - 1) % size, 0) @ rhs_block
  out_summand_hi = lhs_blocks((idx + 1) % size, 1) @ rhs_block
  for i in range(1, size):
    out_summand_lo = shift_up(out_summand_lo)
    out_summand_hi = shift_dn(out_summand_hi)
    out_summand_lo += lhs_blocks((idx - i - 1) % size, 0) @ rhs_block
    out_summand_hi += lhs_blocks((idx + i + 1) % size, 1) @ rhs_block
  return jnp.concatenate([out_summand_lo, out_summand_hi])
out = matmul_psumscatter_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

神经网络#

我们可以使用 shard_map 来并行化神经网络中的计算,无论是单独使用还是与 jax.jit 中的自动分区结合使用。本节给出了一些基于这个玩具神经网络和随机数据的示例。

import jax
import jax.numpy as jnp

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jax.nn.relu(outputs)
  return outputs

def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
def init_layer(key, n_in, n_out):
  k1, k2 = jax.random.split(key)
  W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
  b = jax.random.normal(k2, (n_out,))
  return W, b

def init(key, layer_sizes, batch_size):
  key, *keys = jax.random.split(key, len(layer_sizes))
  params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))

  key, *keys = jax.random.split(key, 3)
  inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
  targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))

  return params, (inputs, targets)
layer_sizes = [784, 128, 128, 128, 128, 128, 8]
batch_size = 32

params, batch = init(jax.random.key(0), layer_sizes, batch_size)

将这些示例与纯 “分布式数组和自动分区”文档中的自动分区示例 进行比较。在那些自动分区示例中,我们不需要编辑模型函数来使用不同的并行化策略,但在 shard_map 中,我们通常需要编辑。

8 路批次数据并行#

最简单的多设备并行策略是将输入和目标批次在多个设备上进行分片,在这些设备上复制参数,并并行地将模型应用于这些分片数据。为了评估总损失,这些设备只需要在最后进行一个标量大小的全简化求和。(为了评估损失的梯度,这些设备必须在反向传播中对参数梯度进行全简化求和。)

from functools import partial

from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from jax.experimental import mesh_utils

devices = mesh_utils.create_device_mesh((8,))

# replicate initial params on all devices, shard data batch over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P()))

# adapt the loss function to sum the losses across devices
def loss_dp(params, batch):
  @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())
  def loss_spmd(local_batch):
    inputs, targets = local_batch
    predictions = predict(params, inputs)  # use reference 'predict`
    local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
    return jax.lax.pmean(local_loss, 'batch')
  return loss_spmd(batch)

我们可以检查损失及其梯度是否与参考(基础)模型匹配。

print(jax.jit(loss)(params, batch))
print(jax.jit(loss_dp)(params, batch))
22.779888
22.779888
def allclose(a, b):
  return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))

print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_dp))(params, batch)))
True

我们可以打印编译器 IR 来检查梯度计算,并验证集体全简化求和操作发生在预期的地方:在正向传播的末尾计算损失值,以及在反向传播中计算总参数梯度。

8 路全分片数据并行 (FSDP)#

另一种策略是将参数在设备上进行分片,当需要完整的值用于 jnp.dot 或偏差加法时,将每个参数进行全收集。由于我们一次只在一个本地设备内存中有一个完整参数,而不是像前面 DP 示例那样将所有参数保存在所有设备内存中,因此我们释放了大量的内存,可以用于更大的模型或更大的批次大小。而且因为 XLA 会重叠计算和设备间通信,所以时钟时间不会受到影响。

所以现在我们需要在两个地方进行集体操作:模型预测函数 predict 需要在使用参数之前进行全收集,并且与 DP 案例一样,损失函数需要将局部损失求和以计算总损失。

我们还需要另一个成分:我们不想存储正向传播中完全收集的参数以供反向传播使用。相反,我们希望在反向传播中再次收集它们。我们可以通过使用 jax.remat 来表达这一点,该函数具有一个 自定义策略(或 custom_vjp),尽管 XLA 通常会自动进行重新计算。

这种通用的 FSDP 方法 类似于 权重更新分片 (WUS)ZeRO-3

# shard data batch *and params* over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P('batch')))

# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them)
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp(params_frag, inputs):
  for W_frag, b_frag in params_frag:
    W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
    b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
    outputs = jnp.dot(inputs, W) + b
    inputs = jax.nn.relu(outputs)
  return outputs

def loss_fsdp(params, batch):
  @partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())
  def loss_spmd(local_params, local_batch):
    inputs, targets = local_batch
    predictions = predict_fsdp(local_params, inputs)
    local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
    return jax.lax.pmean(local_loss, 'batch')
  return loss_spmd(params, batch)

同样,我们可以检查损失及其梯度是否与参考模型匹配。

print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp)(params, batch))

print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_fsdp))(params, batch)))
22.779888
22.779888
True

8 路张量并行 (TP)#

通常我们不会单独使用张量模型并行,但单独查看它对于并行矩阵乘法来说是一个很好的热身。这也是在库函数中使用 shard_map 的一个很好的例子,该函数在更大的 jit 基于计算的调用中。

并行化思路是我们会将数据/激活在特征轴上进行分片(而不是在批次轴上),并且我们也会将权重矩阵在输入特征轴上进行分片(以及偏差在特征轴上进行分片)。然后为了执行并行矩阵乘法,我们将执行局部矩阵乘法,然后进行 psum_scatter 来对局部结果求和并有效地散布结果的分片。

devices = mesh_utils.create_device_mesh((8,))
mesh = Mesh(devices, ('feats',))

batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))
params = jax.device_put(params, NamedSharding(mesh, P('feats')))

def predict_tp(params, inputs):
  for W, b in params:
    outputs = gemm_tp(inputs, W, b)
    inputs = jax.nn.relu(outputs)
  return outputs

@partial(shard_map, mesh=mesh,
         in_specs=(P(None, 'feats'), P('feats', None), P('feats')),
         out_specs=P(None, 'feats'))
def gemm_tp(inputs, W, b):
  block_result = jnp.dot(inputs, W)
  return jax.lax.psum_scatter(block_result, 'feats',
                              scatter_dimension=1, tiled=True) + b

def loss_tp(params, batch):
  inputs, targets = batch
  predictions = predict_tp(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1))  # NOTE psum!

FSDP + TP,在顶层使用 shard_map #

我们可以将这些策略组合在一起,使用多个并行轴。

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('batch', 'feats'))

batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))
params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))

# mostly same as previous predict_fsdp definition, except we call gemm_tp
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp_tp(params_frag, inputs):
  for W_frag, b_frag in params_frag:
    W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
    b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
    block_result = jnp.dot(inputs, W)
    outputs = jax.lax.psum_scatter(block_result, 'feats',
                                   scatter_dimension=1, tiled=True) + b
    inputs = jax.nn.relu(outputs)
  return outputs

@partial(shard_map, mesh=mesh,
         in_specs=(P(('feats', 'batch')), P('batch', 'feats')),
         out_specs=P())
def loss_fsdp_tp(local_params, local_batch):
  inputs, targets = local_batch
  predictions = predict_fsdp_tp(local_params, inputs)
  sq_err = jax.lax.psum(jnp.sum((predictions - targets)**2, axis=-1), 'feats')
  return jax.lax.pmean(jnp.mean(sq_err), 'batch')

请注意,我们必须进行 *两次* 集体约简:一次在 'feats' 上,一次在 'batch' 上。在纯 TP 示例中,我们没有明确地编写 'feats' 约简,因为我们只在 gemm_tp 中使用 shard_map;在调用者 loss_tp 中,编译器会自动将我们对 jnp.sum 的使用转换为执行一个 psum,根据 predict_tp 返回的分片结果进行调整。

print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp_tp)(params_, batch_))

print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_fsdp_tp))(params, batch)))
22.779886
22.779886
True

SPMD 管道并行 (PP)#

使用管道并行,我们的目标是并行化网络中不同深度层级的评估。例如,一个设备可以计算第一层的应用,而另一个设备可以计算第二层的应用;当它们完成时,第一个设备将其结果传递给第二个设备,而第二个设备将其结果传递给负责第三层的设备,并重复此过程。通常,管道阶段的数量可能与层的数量不同,因为每个阶段可能负责多个层。

使用 SPMD 管道,我们利用了网络中大多数层都应用计算这一事实,只是参数值不同。特别是,我们可以将所有参数堆叠在一起,除了第一层和最后一层的参数,然后使用 shard_map 映射这些层参数的块,其中每个参数块对应于一个管道阶段。然后,我们使用 jax.lax.ppermute 集体操作将数据向下移动到并行管道中。

这种特定的管道策略本质上是 GPipe 策略。还有几种变体,以及截然不同的策略,哪种策略适合取决于阶段之间的网络速度和批次大小。但对于本教程,我们将重点关注一种策略。

首先,我们选择一些管道参数。

L = len(params) - 2        # num layers, excluding first and last
N = batch_size             # batch size
F = params[0][0].shape[1]  # num features

# choose some pipeline parameters
S = 2      # number of stages
B = 8      # size of each microbatch
assert L % S == 0, "S (number of stages) must divide L (number of inner layers)"

# compute some useful quantities
M, ragged = divmod(N, B)  # M is number of microbatches
assert not ragged, "B (size of each microbatch) must divide total batch size"
K, ragged = divmod(M, S)  # K is microbatches per stage
assert not ragged, "S (number of stages) must divide number of microbatches"
print(f'{S} stages, {L // S} layer(s) per stage, {L} pipelined layers total')
print(f'{B} examples per microbatch, {M} microbatches total')
2 stages, 2 layer(s) per stage, 4 pipelined layers total
8 examples per microbatch, 4 microbatches total
mesh = Mesh(jax.devices()[:S], ('stages',))

def predict_pp(params, inputs):
  (W_first, b_first), inner_params, (W_last, b_last) = params
  inputs = jax.nn.relu(jnp.dot(inputs, W_first) + b_first)
  inputs = spmd_pipeline(lambda Wb, x: jax.nn.relu(x @ Wb[0] + Wb[1]),
                        inner_params, inputs)
  outputs = jnp.dot(inputs, W_last) + b_last
  return outputs

@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),
         out_specs=P())
def loss_pp(params, batch):
  inputs, targets = batch
  predictions = predict_pp(params, inputs.reshape(K, B, -1)).reshape(K * B, -1)
  local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
  return jax.lax.pmean(local_loss, 'stages')
def spmd_pipeline(fn, stage_params, inputs):
  stage = jax.lax.axis_index('stages')
  outputs = jnp.zeros_like(inputs) * jnp.nan
  state = jnp.zeros((L // S, B, F)) * jnp.nan
  for i in range(M+L-1):
    state = state.at[0].set(jnp.where(stage == 0, inputs[i % K], state[0]))
    state = jax.vmap(fn)(stage_params, state)
    outputs = outputs.at[(i-L+1) % K].set(jnp.where(stage == S-1, state[-1], outputs[(i-L+1) % K]))
    state, inputs, outputs = shift(i, state, inputs, outputs)
  outputs = jax.lax.ppermute(outputs, 'stages', [(i, (i+1) % S) for i in range(S)])
  return outputs

def shift(i, state, inputs, outputs):
  sh = lambda x, d: jax.lax.ppermute(x, 'stages', [(i, (i+d) % S) for i in range(S)])
  state = jnp.roll(state, +1, axis=0).at[0].set(sh(state[-1], +1))
  if (i % K) == (-1 % K):
    inputs = sh(inputs, +1)
  if ((i-L+1) % K) == (-1 % K):
    outputs = sh(outputs, +1)
  return state, inputs, outputs
first_params, *inner_params, last_params = params
Ws, bs = zip(*inner_params)
params_stacked = jnp.stack(Ws), jnp.stack(bs)
first_params = jax.device_put(first_params, NamedSharding(mesh, P()))
params_stacked = jax.device_put(params_stacked, NamedSharding(mesh, P('stages')))
last_params = jax.device_put(last_params, NamedSharding(mesh, P()))
params_ = first_params, params_stacked, last_params

batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages')))
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_pp)(params_, batch_))
22.779886
22.779884
_ = jax.jit(jax.grad(loss_pp))(params_, batch_)   # don't crash