复制诱导集体的高效转置#

mattjj@, dougalm@

2023 年 8 月

动机#

我们在自动转置包含某些集体操作的shmap时遇到了效率问题。这个问题出现在psumall_gather中,特别是在集体操作的输出作为未映射输出返回给调用者时。这并不是一个边缘情况:例如,当对使用psum计算总损失的基于shmap的批量数据并行神经网络损失函数应用grad时就会出现这种情况。

我们已经知道这个问题有一段时间了。pmap也存在类似的问题,不过可以通过将grad保留在pmap内部而不是外部来解决。不完整名称aval的工作的主要目标之一就是解决此转置效率问题的某个版本。本文档借鉴了这些想法,同时扩展和修改了它们以处理更多情况并更容易落地。事实上,这里提出的解决方案只影响shmap的实现。系统其余部分无需更改(至少目前不用)。

本文档的主要目的是定义此转置效率问题并提出一个易于落地的解决方案。

本文档不涉及

  • 数组上的逻辑轴名称(这里的轴名称与shmap和原始pmap中的完全相同);

  • 更改自动微分语义(所有数字和(非)错误保持不变,我们只是使事情更高效);

  • 允许用户代码反映任何新信息,或真正影响用户代码。

问题:psumall_gather的高效转置取决于余切是否在设备之间不变#

考虑这个半现实的例子,它旨在类似于一个复制参数的批量数据并行损失函数

devices = jax.devices()  # 8 devices

@partial(shmap, mesh=Mesh(devices, ('batch',)),
         in_specs=(P(None, None), P('batch', None)),
         out_specs=P())
def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
  global_loss = lax.pmean(local_loss, 'batch'))
  return global_loss

注意out_specs=P(),它表示未映射输出。如果您不熟悉未映射输出的概念,请参阅本文档底部的附录。

损失示例中的大部分细节并不重要。对我们的目的而言,唯一重要的是我们在最后应用了psum(更确切地说,是pmean = lambda x, name: psum(x, name) / psum(1, name))。因此,简化后的版本如下所示

# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

我们甚至通过抑制mesh参数简化了表示法。在后续示例中,可以从上下文中推断出来。

转置是什么样子的?用t表示函数转置,我们可以通过应用以下函数¿f1_transpose?高效地评估任何ybart(f1)(ybar)

# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))

但这不是我们目前作为t(f1)得到的转置。

相反,当前的转置方法大致是交换in_specsout_specs,对未映射输出进行一些除法重新缩放,并转置主体。因为psum是其自身的转置(作为全归约求和),所以我们最终会产生以下转置

# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
              in_specs=P(), out_specs=P('i'))

此转置得到了正确的数字,但它效率低下。我们从转置的in_specs=P()静态地知道ybar对于每个函数实例都具有相同的值,即对于命名为i的网格轴上的设备,其值对于设备而言是不变的,但我们仍然对其应用了psum!这使用了昂贵的通信,只是为了将每个设备上的值乘以8。(这里的8指的是轴i的大小。除以8来自原始函数的out_specs=P();它和微不足道的psum基本上相互抵消。)

我们做错了什么?我们没有利用与f1的未映射输出相对应的余切ybar保证是设备不变的事实;相反,我们防御性地对其应用了psum,就好像它们不是设备不变的,因为psum的转置在给定其拥有的本地信息的情况下无法确定。有时psum是必要的,例如相对于其第一个参数转置f2

# Example 2: shmap involving psum and *mapped* output with efficient transpose
f2 = shmap(lambda x, y: psum(g(x), 'i') * y,
          in_specs=(P('i'), P('i')), out_specs=P('i'))

# The transpose we currently get for Example 2 is efficient
t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')),
                in_specs=(P('i'), P('i')), out_specs=P('i'))

直观地说,如果我们的转置机制能够区分示例1和示例2,那么我们就可以通过在可能的情况下避免psum和除法来做得更好。

低效的示例可以更小。考虑转置这个糟糕的恒等函数

# Example 3: cursed identity
cursed_identity = shmap(lambda x: x, P(), P())

# Currently we get these inefficient transposes
t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P())
t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i')), P(), P())
...

它在我们进行更多转置时会变得越来越大。真是尴尬!

而且psum并不是唯一的罪魁祸首。all_gather也存在类似的情况

# Example 4: all_gather to an unmapped output
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())

# Currently we get this inefficient transpose
t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i'))

此程序有点人为。为什么要执行all_gather并将结果馈送到未映射输出,而不是跳过主体中的all_gather并只使用out_specs=P('i')来收集结果?但是,即使它是人为设计的,此示例仍然展示了一个不必要地执行通信的转置(我们本可以只执行一个非通信切片),类似于psum的示例1。

同样类似于psum示例,防御性psum_scatter在某些情况下是必要的

# Example 5: all_gather to a mapped output
f5 = shmap(lambda x, y: all_gather(x, 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Currently we get this efficient transpose
t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))

那么我们如何避免这些低效的转置呢?

解决方案#

这里有两个解决方案的想法。它们不是互斥的。但是(剧透)第二个更好,而且是我们唯一需要的。

部分解决方案“P-sum”:构建将psum表达成out_specs的能力#

这个解决方案有点像稻草人,因为它只提供了一种笨拙的编写程序的方式。它甚至无法修复所有问题!但值得考虑,即使只是为了激发更完整的解决方案。

上面的示例4是人为的,因为我们本可以使用out_specs而不是主体中的all_gather

# Example 4 again
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())

# Why didn't we just write it like this?
f4_better = shmap(lambda x: x, P('i'), P('i'))

f4_better版本没有任何转置问题,因为转置问题是由主体中的集体操作引起的。

类似地,我们可以通过扩展out_specs使其能够表达求和来修复示例1

# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# What if we could write an output sum like this?
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i'))  # sum='i' means sum over that axis

# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), P(sum='i'))

因此,在out_specs中提供内置的psum可以解决示例1的转置问题。但它没有完全修复示例3中的糟糕恒等转置

# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())

# How it would transpose with the P-sum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))

这是一个改进,因为程序不会随着我们不断转置而变得越来越大,但我们仍然在进行浪费的通信。

完整解决方案:静态跟踪设备变化与设备不变的中间变量,以及新的原语#

此解决方案有两个组成部分

  1. 跟踪值在特定网格轴上保证是设备不变还是设备变化,以及

  2. psum分解为两步过程,引入一个新的pbroadcast原语,并为all_gather及其转置引入新的原语。

从道德上讲,设备不变与设备变化信息的跟踪是类型级别的考虑。但为了我们第一个实现的便捷性,我们不需要将信息真正添加到抽象值或jaxpr类型中。在进入实现之前,我们将首先使用类型介绍这个概念。

接下来还会讨论使用户API方便且向后兼容。但为了首先介绍这个概念,我们将忽略方便性,而是编写尽可能明确的代码。

在aval中跟踪设备不变性(又名带名称的aval,复活了)#

我们有时可以仅从静态信息中得知shmap主体中某些中间变量的值保证在网格轴上是不变的,从某种意义上说,网格轴上的函数实例(及其对应的设备)必须全部使用相同的值进行计算。我们将此类值称为设备不变值。对于不是设备不变的值,我们将称之为设备变化值,尽管实际上我们的意思是说从类型系统的角度来看,它们可能是设备变化值。

为了在类型中编码设备变化,我们将扩展数组类型的语法。我们将编写类似x:f32[3,4]{i}的内容来表示x在网格轴i上(可能是)设备变化的(并且在shmap的任何其他网格轴上都是设备不变的)。更一般地,我们将说数组类型语法的语法类似于

shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...}

我们还将更新类型规则以处理设备变化类型

  • 对于除集体操作以外的一阶原语

    • 对于多元基元,操作数设备变体类型必须相等,其中形状必须相等,例如 mul x:f32[s1]{r1} y:f32[s2][r2] 需要 r1 == r2 以及 s1 == s2

    • 输出设备变体类型必须与操作数相同。

  • 对于高阶基元

    • 我们只需实例化任何类型变量,包括设备变体类型(并在检查类型以确保相等时检查它们的设备变体类型是否相等)。

    • (在执行类型推断时,例如对于 cond 的分支,我们采用设备变体类型中轴名称集的并集)。

  • 对于一阶集体操作

    • 集体操作可以接受设备变化或设备不变的输入(沿着与其轴名称参数对应的网格轴);将设备不变的操作数传递给接受设备变化操作数的集体操作,反之亦然,都是错误的。

    • 集体操作可以产生设备变化或设备不变的输出。

    • 请参见下表作为附带好处,实现此类型检查的任何逻辑都可以包含 shmap 的“静态分析”检查,以检查 shmap 体函数是否与任何未映射的 out_specs 兼容。

这是一个总结集体基元设备变体类型的表格。

名称

设备变体类型

示例

降低到 HLO

转置

psum2

Varying -> Invariant

y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i')

AllReduceSum(通信)

pbroadcast

pbroadcast

Invariant -> Varying

y:f32[3]{i} = pbroadcast(x:f32[3], 'i')

无操作(无通信)

psum

all_to_all

Varying -> Varying

y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0) AllToAll(通信)

all_to_all

axis_index

() -> Varying

idx:i32[]{i} = axis_index('i')

ReplicaId 和一些算术运算(无通信)

n/a

psum_scatter

Varying -> Varying

y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i')

ReduceScatterSum(通信)

all_gather

all_gather

Varying -> Varying

y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i')

AllGather(通信)

psum_scatter

pscatter

Invariant -> Varying

y:f32[2]{i} = pscatter(x:f32[16], 'i')

lambda x: x[axis_index('i'), None](无通信)

all_gather_invariant

all_gather_invariant

Varying -> Invariant

y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i')

AllGather(通信)

pscatter

这里有一些令人惊讶的事情!

  • 我们引入了几个新的基元,包括

    • pbroadcast,有趣的是它降低到无操作。

    • all_gather_invariant,它降低到与 all_gather 相同的结果,但具有不同的设备变体类型(本质上 all_gather 融合了 pbroadcast,而 all_gather_invariant 没有)。

    • pscatterall_gather_invariant 的对偶(转置)。

  • all_gather 具有设备变化的结果。

直观地说,引入 pbroadcast 的原因(除了使类型规则有效之外)是为了使 psum 能够转置为物理无操作。我们需要 all_gather 具有设备变化的结果的原因是,以便我们可以将其转置为 psum_scatter;如果我们改为保留其设备不变的结果,我们可能需要下游的 pbroadcast,并且该组合将转置为效率低下的 psum,然后是切片/ pscatter。因此,我们改为将 pbroadcast“融合”到 all_gather 中,从而允许高效地转置为 psum_scatter。我们提供 all_gather_invariant 及其转置 pscatter 主要用于完整性;用户不太可能需要它(它对应于示例 4 中的情况,使用 out_specs 可以轻松地以不同的方式编写)。

有趣的是,psumpbroadcast 转置对对应于用户在使用 pmap 训练大型语言模型时引入的 psum_idrevid_psumrev

此系统如何解决效率低下的转置示例#

再次考虑简化的激励示例

# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
  w:f32[]{i} = g(x)
  y:f32[]{} = psum(w, 'i')
  return y

使用这些新规则,转置为

# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
              in_specs=P(), out_specs=P('i'))

# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1_transpose(ybar: f32[]):
  wbar:f32[]{i} = pbroadcast(ybar, 'i')
  xbar:f32[3,4]{i} = transpose(g)(wbar)
  return xbar

其中评估 pbroadcast 应用根本不涉及通信或浮点运算;它是一个无操作。请注意,如果我们继续转置,主体的大小不会增加;实际上 t(t(f1)) == f1。效率得到了提升!

并且我们也不会搞砸其他示例,只要我们 pbroadcast 以使类型检查在需要时进行。

# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))


# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device varying type.

# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P())

直观地说,在示例 1 中,我们现在只有“原始 psum 的一半”,而在示例 2 中,我们得到了“两半”。对于示例 3,我们根本不需要主体中的任何操作。

对于 all_gather 示例,示例 4 需要使用 all_reduce_invariant 才能进行高效的转置(尽管最好改为使用 out_specs 而不是主体中的集体操作)。

# Example 4 rewritten with explicit all_reduce_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())

# Example 4 with intermediate device variance types annotated
@partial(shmap, P('i'), P())
def f4(x:f32[1]{i}):
  y:f32[8]{} = all_gather_invariant(x, 'i')
  return y

# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
  xbar:f32[1]{i} = pscatter(ybar, 'i')
  return xbar

对于示例 5,使用设备变化的 all_gather 可以按我们期望的方式工作。

# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
  z:f32[8]{i} = all_gather(x, 'i')
  w:f32[8]{i} = z * y
  return w

# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
  zbar:f32[8]{i} = wbar * y
  xbar:f32[1]{i} = psum_scatter(zbar, 'i')
  return xbar

如何使 API 对用户友好(并向后兼容)#

但是哪个用户想要编写 pbroadcast?哪个开发人员想要破坏大量涉及 psum 的现有用户代码,而这些代码没有馈送到未映射的输出?不是我!

相反,我们可以自动插入 pbroadcast。这有点类似于我们在 jax.numpy 层执行自动秩提升的方式,插入广播以避免二元运算符中的秩不匹配错误。但它要简单得多,因为我们不需要处理形状元组。典型规则是:每当我们看到一个多元操作,其中操作数在其设备变体类型中不一致时,取操作数的设备变体类型轴名称集的并集,并插入 pbroadcast 以将每个操作数提升到结果设备变体类型。

在需要之前自动插入 pbroadcast 可能意味着我们对同一操作数多次应用相同的 pbroadcast,从而创建公共子表达式。当我们进行转置时,这些可能会变成 psum 的总和,而不是 psum 的总和。我们将依靠编译器根据需要清理它。如果这是一个问题,那么我们可以向 pbroadcast 插入过程添加一些简单的记忆化。

用户对 all_gather 的 API 将意味着默认情况下为 all_gather_p(而不是 all_gather_invariant_p),涵盖常见情况,这意味着无需插入 pbroadcast

我们可以在 shmap 上提供一个选项来禁用此 pbroadcast 的自动插入,在这种情况下,用户需要确保类型正确性。对于那些希望明确指出反向传播中 psum 出现在哪里的人来说,此显式选项可能很有吸引力。

如何实现解决方案#

使实现轻量级的关键在于我们不会将这些类型添加到 aval 或 jaxpr 中。至少,目前不会。这可能很昂贵,因为它需要更新 JAX 的其余部分,例如,aval 和 jaxpr 的所有使用者可能都需要处理新类型。我们不会再犯同样的错误!

相反,我们将把这些扩展类型作为 shmap 内部元数据保留下来,就像当前的“out_specs 的复制检查”机制是 shmap 内部的一样。实际上,此解决方案相当于对现有机制进行相对较小的扩展:它已经跟踪了相同的信息;现在我们只是添加了 pbroadcast

我们至少有两个选项可以选择在哪里执行 pbroadcast 插入

  1. 在转置之前,在转置规则中,我们有要转置的计算的 jaxpr;

  2. 在每个 shmap 主体中,无论是在执行时还是分阶段执行,都像当前用于“复制检查 out_specs”的机制一样。前者最终可能更容易,因为我们只需要处理 jaxpr 案例和线性基本运算。但我们将首先尝试后者,以便这里的实现是对现有复制检查逻辑的严格修订/扩展。

附录:定义和说明具有未映射输入和输出的映射#

为了具体起见,我们将主要关注 shmap,尽管这些相同的思想也适用于例如 pmap 和可能 xmap

in_specs 中的相应条目未提及该网格轴的名称时,参数/输入沿网格轴是未映射的。从逻辑上讲,这意味着沿该网格轴的每个函数实例都获得参数的相同值。对于调用者,每个操作数根据映射操作数的网格轴进行切片,而对于未映射操作数的网格轴,则没有切片。

out_specs 中的相应条目未提及该网格轴的名称时,输出沿网格轴是未映射的。从逻辑上讲,这意味着沿该网格轴的每个函数实例必须返回相同的值。对于调用者,shmap 的每个结果都是通过连接沿其映射输出的每个函数实例的返回值形成的,而对于输出未映射的网格轴,仅使用该值的副本。

请参阅 shmap JEP 以了解未映射输入和输出的示例。相比之下,在 vmap 中,未映射的输入/输出通过使用 in_axes / out_axes 的值为 None(而不是 int)来指示。

以下是我们喜欢 shmap 的未映射输入和输出的原因

  • pjit 的表达能力相同。pjit 可以做到的任何事情,shmap 的转义口也应该能够做到。否则,我们的转义口就会有所欠缺!如果 shmap 中没有未映射的输出,那么我们就无法像 pjit 一样表达相同的批量并行损失函数计算。

  • 闭包输入。闭包输入本质上对应于未映射的输入,并且……

  • 在转置下闭包。一旦我们有了未映射的输入,就可以自然地转置到未映射的输出。

因此,未映射的输出既规范又实用!