shmap (shard_map) 用于简单的单设备代码#

sholto@, sharadmv@, jekbradbury@, zhangqiaorjc@, mattjj@

2023 年 1 月

这是建议 shard_map 的设计文档。 您可能需要 最新的用户文档

动机#

JAX 支持两种用于多设备编程的思想流派

  1. 编译器,接管! 让编译器自动将批量数组函数在设备上进行分区。

  2. 让我写出我想要表达的,该死的! 给我单设备代码和显式的通信集合。

我们需要为两者提供出色的 API,并且它们需要相互组合,而不是相互排斥的替代方案。

使用 pjit (现在只是 jit),我们有了 下一代 API 用于第一类问题。但我们还没有完全提升第二类问题的水平。pmap 遵循第二类方法,但随着时间的推移,我们发现它有 致命的缺陷xmap 解决了这些缺陷,但它并没有完全提供按设备划分的形状,而且还包含其他几个重要概念。同时,对按设备显式集合编程的新需求已经出现,例如在 高效扩展 Transformer 推理 中。

我们可以使用 shmap 来提升第二类问题的水平。shmap

  • 一个简单的多设备并行 API,它允许我们编写带有显式集合的按设备代码,其中逻辑形状与按设备物理缓冲区形状匹配,并且集合与跨设备通信完全对应;

  • 一个具有精简功能和一些调整的 xmap 的特殊化版本;

  • XLA SPMD 分区器的“手动”模式的相当直接的呈现;

  • 一个有趣的发音,西式风格的名称,可以代表 shard_mapshpecialized_xmapsholto_mapsharad_map

对于 pjit 用户shmap 是一个补充工具。它可以在 pjit 计算中使用,以临时进入“手动集合”模式,就像一个从编译器自动分区中逃脱的出口。这样,用户可以在大部分代码中使用 pjit 的便利性和熟悉的纯 NumPy 编程模型,并且能够在需要时使用 shmap 手动优化集合通信。这是两全其美!

对于 pmap 用户shmap 是一个严格的升级。它更具表现力、性能更高,并且可以与其他 JAX API 组合使用,而不会使基本批处理数据并行变得更加困难。

有关实际使用的更多信息,您可以跳转到 何时应该使用 shmap,何时应该使用 pjit。如果您想知道为什么我们需要一个新事物,或者 pmap 有什么问题,请跳转到 为什么 pmapxmap 不能解决这个问题?。或者继续阅读下一节,了解一些 shmap 示例和 API 规范。

那么,让我们看看 shmap#

TL;DR 示例(稍后会进行更详细的解释)#

Sho shick

from functools import partial

import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = jax.make_mesh((4, 2), ('i', 'j'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 32.).reshape(16, 32)

@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
         out_specs=P('i', None))
def matmul_basic(a_block, b_block):
  # a_block: f32[2, 8]
  # b_block: f32[8, 32]
  z_partialsum = jnp.dot(a_block, b_block)
  z_block = jax.lax.psum(z_partialsum, 'j')
  return z_block

c = matmul_basic(a, b)  # c: f32[8, 32]

请注意

  • pmap 不同,多个并行轴不需要嵌套(或 axis_index_groups);

  • pmap 和硬 xmap 不同,调用者中不需要重塑,并且逻辑形状对应于按设备的物理形状,这与(非硬)xmap 不同;

  • 通过使用 mesh 精确控制设备放置,这与 pmap 不同;

  • xmap 不同,逻辑和物理只有一个轴名称集;

  • 结果是一个 jax.Array,它可以有效地传递给 pjit,这与 pmap 不同;

  • pmap 不同,此相同代码可以在 pjit/ jit 内有效工作;

  • 此代码可以急切地运行,因此我们可以在中间使用 pdb 并打印值,这与 xmap 的当前实现不同(尽管在设计上,没有顺序调度的 xmap 原则上也可以急切地工作)。

这是另一个带有完全分片结果的 matmul 变体

@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
         out_specs=P('i', 'j'))
def matmul_reduce_scatter(a_block, b_block):
  # c_partialsum: f32[8/X, 32]
  c_partialsum = jnp.matmul(a_block, b_block)
  # c_block: f32[8/X, 32/Y]
  c_block = jax.lax.psum_scatter(c_partialsum, 'j', scatter_dimension=1, tiled=True)
  return c_block

c = matmul_reduce_scatter(a, b)

慢下来,从基础知识开始!#

数组轴上的降秩与保秩映射#

我们可以将 pmap (以及 vmapxmap)视为沿轴展开每个数组输入(例如,将二维矩阵解包成一维行),将其主体函数应用于每个部分,并将结果堆叠在一起,至少在不涉及集合时是这样

pmap(f, in_axes=[0], out_axes=0)(xs) == jnp.stack([f(x) for x in xs])

例如,如果 xs 的形状为 f32[8,5],则每个 x 的形状为 f32[5],如果每个 f(x) 的形状为 f32[3,7],则最终堆叠的结果 pmap(f)(xs) 的形状为 f32[8,3,7]。也就是说,主体函数 f 的每次应用都以比 pmap(f) 的相应参数少一个轴的输入作为参数。我们可以说这些是降秩映射,具有输入/输出的解堆叠/堆叠。

f 的逻辑应用程序的数量由映射的输入轴的大小决定:例如,如果我们映射大小为 8 的输入轴,则在语义上我们获得 8 个函数的逻辑应用程序,对于 pmap,这始终对应于 8 个物理计算它们的设备。

相反,shmap 没有这种降秩行为。相反,我们可以将其视为沿输入轴切片(或“解串联”)成块,应用主体函数,并将结果重新连接在一起(同样,当不涉及集合时)

devices = np.array(jax.devices()[:4])
m = Mesh(devices, ('i',))  # mesh.shape['i'] = 4

shard_map(f, m, in_specs=P('i'), out_specs=P('i'))(y)
==
jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, 4)])

回想一下,jnp.split 将其输入切片成大小相等的块,这些块具有相同的秩,因此如果在上面的示例中 y 的形状为 f32[8,5],则每个 y_blk 的形状为 f32[2,5],如果每个 f(y_blk) 的形状为 f32[3,7],则最终连接的结果 shard_map(f, ...)(y) 的形状为 f32[12,7]。因此,shmapshard_map)映射其输入的分片或块。我们可以说它是具有输入/输出的解串联/串联的保秩映射

f 的逻辑应用程序的数量由网格大小决定,而不是由任何输入轴大小决定:例如,如果我们有一个总大小为 4 的网格(即,在 4 个设备上),那么在语义上,我们获得 4 个函数的逻辑应用程序,对应于 4 个物理计算它们的设备。

使用 in_specs 控制每个输入如何拆分(解串联)和切片#

每个 in_specs 使用 PartitionSpec 通过名称标识相应输入数组的某些轴和网格轴,表示如何将该输入拆分(或解串联)为应用主体函数的块。这种标识确定了分片大小;当输入轴使用网格轴标识时,输入沿着该逻辑轴拆分(解串联)成数量等于相应网格轴大小的多个部分。(如果相应的网格轴大小不能均匀分割输入数组轴大小,则会发生错误。)如果输入的 pspec 没有提到网格轴名称,则不会在该网格轴上进行拆分。例如

devices = np.array(jax.devices())
m = Mesh(devices.reshape(4, 2), ('i', 'j'))

@partial(shard_map, mesh=m, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
  print(x_block.shape)
  return x_block

x1 = np.arange(12 * 12).reshape(12, 12)
y = f1(x1)  # prints (3,12)

这里,由于输入 pspec 没有提到网格轴名称 'j',因此没有输入数组轴在该网格轴上拆分;类似地,由于输入数组的第二个轴没有使用任何网格轴标识(因此也没有在该网格轴上拆分),因此 f1 的应用沿该轴获取输入的完整视图。

当输入 pspec 中没有提到网格轴时,我们始终可以重写为效率较低的程序,其中提到所有网格轴,但调用者执行 jnp.tile,例如

@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
  print(x_block.shape)
  return x_block

x = np.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, mesh.axis_size['j']))  # x_ has shape (12, 24)
y = f2(x_)  # prints (3,12), and f1(x) == f2(x_)

换句话说,由于每个输入 pspec 可以零次或一次提及每个网格轴名称,而不是必须精确地提及每个名称一次,因此我们可以说,除了内置于其输入中的 jnp.split 之外,shard_map 在其输入中还内置了 jnp.tile,至少在逻辑上是这样(尽管切片可能不需要物理执行,具体取决于参数的物理分片布局)。使用的切片不是唯一的;我们也可以沿第一个轴切片,并使用 pspec P(('j', 'i'), None)

可以在输入上进行物理数据移动,因为每个设备都需要具有相应数据的副本。

使用 out_specs 控制每个输出如何通过串联、块转置和解切片进行组装#

与输入端类似,每个 out_specs 通过名称将一些相应的输出数组的轴与网格轴关联起来,表示输出块(每个 body 函数的应用对应一个输出块,或者等效地,每个物理设备对应一个输出块)应该如何组合在一起以形成最终输出值。例如,在上面的 f1f2 示例中,out_specs 指示我们应该沿着两个轴将块结果连接在一起以形成最终输出,从而在两种情况下都得到一个形状为 (12,24) 的数组 y。(如果 body 函数的输出形状,即输出块的形状,其秩太小而无法满足相应输出 pspec 描述的连接,则会出错。)

当输出 pspec 中未提及网格轴名称时,它表示取消分片:当用户编写一个未提及其中一个网格轴名称的输出 pspec 时,他们承诺输出块在该网格轴上相等,因此在输出中仅使用该轴上的一个块(而不是沿着该网格轴将所有块连接在一起)。例如,使用与上面相同的网格

x = jnp.array([[3.]])

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', 'j'))()
print(z)  # prints the same as jnp.tile(x, (4, 2))

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', None))()
print(z)  # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P(None, None))()
print(z)  # prints the same as jnp.tile(x, (1, 1)), or just x

请注意,闭包一个数组值的 body 函数等效于将其作为参数传递,并具有相应的输入 pspec P(None, None)。作为另一个示例,更紧密地遵循上面的其他示例

@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
  return jax.lax.psum(x_block, 'j')

x = np.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape)  # (12,6)

请注意,结果的第二个轴大小为 6,是输入第二个轴大小的一半。在这种情况下,由于集体 psum,在输出 pspec 中未提及网格轴名称 'j' 所表示的取消分片是安全的,这确保了每个输出块在相应的网格轴上相等。以下是另外两个示例,我们在其中更改了输出 pspec 中提及的网格轴

@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
  return jax.lax.psum(x_block, 'i')

x = np.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape)  # (3,12)


@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
  return jax.lax.psum(x_block, ('i', 'j'))

y5 = f5(x)
print(y5.shape)  # (3,6)

在物理方面,在输出 pspec 中未提及网格轴名称会从输出设备缓冲区组装一个沿着该网格轴具有复制布局的 Array

没有运行时检查以确保输出块在要取消分片的网格轴上实际上是相等的,或者等效地,相应的物理缓冲区具有相同的值,因此可以解释为单个逻辑数组的复制布局。但是我们可以提供一种静态检查机制,该机制会在所有可能不正确的程序上引发错误。

由于 out_specs 可以零次或一次提及网格轴名称,并且由于它们可以以任何顺序提及,因此我们可以说,除了其输出中内置的 jnp.concatenate 之外,shard_map 还在其输出中内置了取消分片和块转置。

在输出上不可能进行物理数据移动,无论输出 pspec 如何。相反,out_specs 仅编码如何将块输出组装成 Array,或者在物理上如何将跨设备的缓冲区解释为单个逻辑 Array 的物理布局。

API 规范#

from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]

def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs
          ) -> Callable:
  ...

其中

  • mesh 编码以数组形式排列且具有关联轴名称的设备,就像它对 xmapsharding.NamedSharding 所做的那样;

  • in_specsout_specsPartitionSpec,它可以仿射地提及来自 mesh 的轴名称(而不是像 xmap 中那样单独的逻辑名称),以分别表达输入和输出的切片/取消连接和连接(而不是像 pmapxmap 所做的那样取消堆叠和堆叠),其中未提及的名称分别对应于复制和取消分片(断言已复制-因此给我一份副本);

  • 传递给 f 的参数的形状与传递给 shard_map-of-f 的参数的秩相同(与秩被减少的 pmapxmap 不同),并且传递给 f 的参数的形状是根据 shard_map-of-f 的相应参数的形状 shape 和相应的 PartitionSpec 规范计算出来的,大致为 tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))

  • f 的主体可以使用来自 mesh 的名称应用集体操作。

shmap 默认是立即执行的,这意味着我们逐个原语地调度计算,以便用户可以在完全复制的值上使用 Python 控制流和交互式 pdb 调试来打印任何值。要将 shmap 映射的函数分段并进行端到端编译,只需在其周围放置一个 jit。一个结果是 shmap 没有像 xmappmap 当前那样自己的调度和编译路径;它只是 jit 路径。

当它被例如封闭的 jit 分段时,将 shmap 降低为 StableHLO 是微不足道的:它只涉及在输入上切换到“手动 SPMD 模式”,并在输出上切换回去。(我们目前不打算支持部分手动、部分自动模式。)

与效果的交互与 pmap 相同。

与自动微分的交互也与 pmap 类似(而不是尝试 xmap 所做的新的语义,对应于具有未映射的中间值,因此 gradreduce_axes 以及使 psum 转置为 pbroadcast 而不是 psum)。但是它因此继承了 pmap 中未解决的问题:在某些情况下,与其将 psum 转置为 psum,从而执行与正向传递 psum 相对应的反向传递 psum,不如将反向传递 psum 移动到反向传递中的其他位置,从而利用线性度。许多高级 pmap 用户通过使用 custom_vjp 来实现 psum_idrevid_psumrev 函数来解决这一挑战,但是由于很容易意外地使这些函数不平衡,因此该技术是一种自伤性武器。我们有一些关于如何以更安全的方式提供此功能的想法。

何时应该使用 shmap,何时应该使用 pjit#

一种理念是:在 jit==pjit 中编写程序几乎总是更简单,但是如果程序的给定部分比编译器可以优化的程度低,则可以进入 shmap

一个真实的例子#

以下是在具有 2D 权重收集模式的 Transformer 层传递中 shmap 的外观(论文,第 5 页上的 3.2.3 节)

def matmul_2D_wg_manual(xnorm, q_wi, layer):
  '''Calls a custom manual implementation of matmul_reducescatter'''
  # [batch, maxlen, embed.X] @ [heads.YZ, embed.X, q_wi_per_head]
  # -> (matmul)
  # -> [batch, maxlen, heads.YZ, q_wi_per_head]{x unreduced}
  # -> (reducescatter over x into X heads, B batches)
  # -> [batch, maxlen, heads.YZX, q_wi_per_head]
  with jax.named_scope('q_wi'):
    xnorm = intermediate_dtype(xnorm)
    q_wi = matmul_reducescatter(
        'bte,hed->bthd',
        xnorm,
        params.q_wi,
        scatter_dimension=(0, 2),
        axis_name='i',
        layer=layer)
   return q_wi


import partitioning.logical_to_physical as l2phys

def pjit_transformer_layer(
    hparams: HParams, layer: int, params: weights.Layer, sin: jnp.ndarray,
    cos: jnp.ndarray, kv_caches: Sequence[attention.KVCache],
    x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Forward pass through a single layer, returning output, K, V."""

  def my_layer(t, axis=0):
    """Gets the parameters corresponding to a given layer."""
    return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)

  # 2D: [batch.Z, time, embed.XY]
  x = _with_sharding_constraint(
      x, ('residual_batch', 'residual_time', 'residual_embed'))
  xnorm = _layernorm(x)
  # 2D: [batch, time, embed.X]
  xnorm = _with_sharding_constraint(
      xnorm, ('post_norm_batch', 'time', 'post_norm_embed'))
  # jump into manual mode where you want to optimise
  if manual:
    q_wi = shard_map(matmul_2D_wg_manual, mesh
                in_specs=(l2phys('post_norm_batch', 'time', 'post_norm_embed'),
                          l2phys('layers', 'heads', 'embed', 'q_wi_per_head')),
                out_specs=l2phys('post_norm_batch', 'time', 'heads', 'q_wi_per_head'))(xnorm, q_wi, layer)
  else:
    q_wi = jnp.einsum('bte,hed->bthd', xnorm, my_layer(params.q_wi))
    # 2D: [batch, time, heads.YZX, None]
    q_wi = _with_sharding_constraint(q_wi,
                                   ('post_norm_batch', 'time', 'heads', 'qkv'))
  q = q_wi[:, :, :, :hparams.qkv]
  q = _rope(sin, cos, q)
  # unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
  # swiGLU with full d_ff dimension, rather than 2/3 scaled
  wi0 = q_wi[:, :, :, hparams.qkv:hparams.qkv + (hparams.ff // hparams.heads)]
  wi1 = q_wi[:, :, :, hparams.qkv + (hparams.ff // hparams.heads):]
  kv = jnp.einsum('bte,ezd->btzd', xnorm, my_layer(params.kv))
  k = kv[:, :, 0, :hparams.qkv]
  v = kv[:, :, 0, hparams.qkv:]
  k = _rope(sin, cos, k)

  y_att = jnp.bfloat16(attention.attend(q, k, v, kv_caches, layer))

  y_mlp = special2.swish2(wi0) * wi1
  # 2D: [batch, time, heads.YZX, None]
  y_mlp = _with_sharding_constraint(y_mlp,
                                    ('post_norm_batch', 'time', 'heads', None))

  y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)
  # do the second half of the mlp and the self-attn projection in parallel
  y_out = jnp.einsum('bthd,hde->bte', y_fused, my_layer(params.o_wo))
  # 2D: [batch.Z, time, embed.XY]
  y_out = _with_sharding_constraint(
      y_out, ('residual_batch', 'residual_time', 'residual_embed'))
  z = y_out + x
  z = _with_sharding_constraint(
      z, ('residual_batch', 'residual_time', 'residual_embed'))
  return z, k, v

在下面的配置文件中,第一个和第二个 matmul 都被手动降低的版本替换,其中计算(融合)与通信(ppermute)完全重叠!我们正在使用延迟优化变体的一个有趣的提示是,ppmerute 像素是抖动的——因为有两个重叠的 ppermute 同时使用相反的 ICI 轴!

All-to-all 更难重叠,因此被搁置了。

image

为什么 pmapxmap 尚未解决此问题?#

pmap 是我们的第一个多设备并行 API。它遵循按设备代码和显式集体操作的原则。但是它有一些主要缺点,使其不适合当今的程序

  • 映射多个轴需要嵌套的 pmap 嵌套的 pmap 不仅编写起来繁琐,而且难以控制(甚至预测)数据和计算的设备放置,也难以保留数据分片(请参阅接下来的两点)。 如今的程序需要多个并行轴。

  • 无法控制设备放置。 特别是在多个并行轴的情况下,程序员需要控制这些轴如何与硬件资源及其通信拓扑对齐。但是(嵌套的)pmap 无法控制映射的程序实例如何放置在硬件上;只有一个用户无法控制的自动设备顺序。(Gopher 使用 axis_index_groups 和一个未嵌套的 pmap 本质上是一种通过将多个并行轴扁平化为一个轴来解决此问题的技巧。)

  • jit/pjit 的可组合性。 jit-of-pmap 是一个性能陷阱,就像嵌套 pmap 一样,例如 scan-of-pmap,因为从内部 pmap 返回时不会保留分片。为了保留分片,我们需要在 jaxpr 上进行模式匹配,以确保我们正在处理完全嵌套的 pmap,或者仅在 jit 内的 pmap。此外,pjit 在这里没有帮助,因为 pmap 的目标是 XLA 副本,而 pjit 的目标是 XLA SPMD 分区器,并且将两者组合起来很困难。

  • jax.Array 兼容性(因此也包括 pjit 兼容性)。 由于 pmap 的输出分片不能表示为 Shardings / OpShardings,因为 pmap 的语义是堆叠而不是连接,所以 pmap 计算的输出目前无法传递给 pjit 计算,而不会反弹到主机(或调度一个重塑计算)。

  • 多控制器语义(因此也包括 pjit 兼容性)。 多控制器 pmap 将值跨控制器连接起来,这很好用,但与单控制器 pmap 的堆叠语义不同。更实际的是,它排除了使用非完全可寻址的 jax.Array 输入和输出,正如我们在多控制器 pjit 中所使用的那样。

  • 急切模式。 我们没有将 pmap 设计为首先是急切的,尽管我们最终(在 4 年多之后!)添加了使用 disable_jit() 的急切操作,但 pmap 中融合了 jit 意味着它有自己的编译和分发路径(实际上是两条分发路径:在 Python 中用于处理 Tracer,在 C++ 中用于原始 Array 输入的性能!),这是一个沉重的实现负担。

  • 需要在调用者中进行重塑。 在 8 个设备上使用 pmap 的典型用例可能看起来像从大小为 128 的批处理轴开始,将其重塑为拆分为两个大小为 (8, 16) 的轴,然后对第一个轴进行 pmap。这些重塑很笨拙,并且编译器通常将它们解释为复制而不是视图 - 增加了内存和时间使用量。

当只进行批处理数据并行时,这些缺点并不算太糟。但是当涉及更多并行性时,pmap 就无法胜任了!

xmappmap 的下一代演进铺平了道路,并解决了(几乎)所有这些问题。shmap 紧随 xmap 的脚步,并以基本相同的方式解决了这些问题;实际上,shmap 就像 xmap 的一个专门子集(有些人称之为 “hard xmap” 子集),进行了一些调整。

对于初始原型,我们选择将 shmap 作为与 xmap 分开的原始类型来实现,因为限制其支持的功能集可以更容易地专注于核心功能。例如,shmap 不允许未映射的中间值,因此更容易不用担心命名轴和自动微分之间的交互。此外,不必推断所有成对功能的交互,可以更容易地添加超出今天在 xmap 中实现的功能,例如对急切模式的支持。

shmapxmap 都共享降低代码的很大一部分。我们将来可以考虑合并两者,甚至仅专注于 shmap,这取决于用法如何演变。