高效转置复制诱导集合运算#

mattjj@, dougalm@

2023 年 8 月

动机#

我们在自动转置包含某些集合操作的 shmap 时存在效率问题。这个问题出现在 psumall_gather 上,特别是当集合操作的输出作为未映射的输出返回给调用者时。而且这并非边缘情况:例如,当对使用 psum 计算总损失的基于 shmap 的批量数据并行神经网络损失函数应用 grad 时,就会出现这种情况。

我们早就知道这个问题。 pmap 也存在类似的问题,不过通过将 grad 保留在 pmap 内部而不是外部来解决。不完整的带有名称的 avals 工作的主要目标是解决这个转置效率问题的一个版本。本文档借鉴了这些想法,同时对其进行扩展和修订,以处理更多情况并更容易实现。事实上,这里提出的解决方案仅影响 shmap 的实现。系统的其余部分不需要更改(暂时)。

本文档的主要目的是定义这个转置效率问题,并提出一个易于实现的解决方案。

本文档不涉及:

  • 数组的逻辑轴名称(这里的轴名称与 shmap 和原始的 pmap 中的一样);

  • 更改自动微分语义(所有的数字和(非)错误保持不变,我们只是让事情更有效率);

  • 允许用户代码反映任何新信息,或者真正影响用户代码。

问题:psumall_gather 的高效转置取决于余切是否在设备之间保持不变#

考虑这个半真实的示例,旨在模拟一个复制参数的批量数据并行损失函数:

devices = jax.devices()  # 8 devices

@partial(shmap, mesh=Mesh(devices, ('batch',)),
         in_specs=(P(None, None), P('batch', None)),
         out_specs=P())
def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
  global_loss = lax.pmean(local_loss, 'batch'))
  return global_loss

请注意 out_specs=P(),它表示未映射的输出。如果您不熟悉未映射输出的概念,请参阅本文档底部的附录。

loss 示例中的大多数细节并不重要。对于我们的目的而言,重要的是我们在最后应用 psum(或者更确切地说是 pmean = lambda x, name: psum(x, name) / psum(1, name))。因此,精简后的版本如下所示:

# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

我们甚至通过抑制 mesh 参数来简化表示法。在接下来的示例中,可以从上下文中推断出它。

转置是什么样的?用 t 表示函数转置,我们可以通过应用下面的函数 ¿f1_transpose? 来有效地计算任何 ybart(f1)(ybar)

# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))

但这不是我们当前得到的 t(f1) 的转置。

相反,当前的转置方法大致是,我们切换 in_specsout_specs,为未映射的输出进行一些除法缩放,然后转置主体。因为 psum 是其自身的转置(作为全归约求和),所以我们最终得到这个转置:

# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
              in_specs=P(), out_specs=P('i'))

这个转置得到的数字是正确的,但很浪费。我们从转置的 in_specs=P() 中静态地知道,ybar 对于每个函数实例都具有相同的值,即它的值对于沿名为 i 的网格轴的设备是设备不变的,但我们却对它应用了 psum!这使用了昂贵的通信,只是为了将每个设备上的值乘以 8。(这里的 8 指的是轴 i 的大小。除以 8 来自原始函数的 out_specs=P();它和简单的 psum 基本上相互抵消了。)

我们哪里做错了?我们没有利用与 f1 的未映射输出对应的余切 ybar 保证是设备不变的这一事实;相反,我们防御性地对它们进行 psum 操作,就好像它们不是设备不变的一样,因为 psum 的转置无法根据它拥有的本地信息确定。有时 psum 是必要的,就像转置 f2 相对于其第一个参数一样:

# Example 2: shmap involving psum and *mapped* output with efficient transpose
f2 = shmap(lambda x, y: psum(g(x), 'i') * y,
          in_specs=(P('i'), P('i')), out_specs=P('i'))

# The transpose we currently get for Example 2 is efficient
t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')),
                in_specs=(P('i'), P('i')), out_specs=P('i'))

直观地说,如果我们的转置机制能够区分示例 1 和示例 2,我们就可以通过尽可能避免 psum 和除法来做得更好。

低效的例子甚至可以更小。考虑转置这个被诅咒的恒等函数:

# Example 3: cursed identity
cursed_identity = shmap(lambda x: x, P(), P())

# Currently we get these inefficient transposes
t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P())
t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i')), P(), P())
...

我们转置的次数越多,它就变得越大。多么尴尬!

而且 psum 不是唯一的罪魁祸首。all_gather 也存在类似的情况:

# Example 4: all_gather to an unmapped output
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())

# Currently we get this inefficient transpose
t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i'))

这个程序有点人为。为什么在主体中执行 all_gather 并将结果馈送到未映射的输出中,而不是跳过主体中的 all_gather 并仅使用 out_specs=P('i') 来收集结果?但是,即使它是人为设计的,这个例子仍然展示了一个不必要地执行通信的转置(我们可以只执行一个非通信切片),类似于 psum 的示例 1。

同样类似于 psum 的示例,防御性的 psum_scatter 在某些情况下是必要的:

# Example 5: all_gather to a mapped output
f5 = shmap(lambda x, y: all_gather(x, 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Currently we get this efficient transpose
t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))

那么我们如何避免这些低效的转置呢?

解决方案#

这里有两个解决方案的想法。它们不是互斥的。但是(剧透)第二个更好,而且这是我们需要的全部。

部分解决方案“P-sum”:构建将 psum 表达为 out_specs 的能力#

这个解决方案有点像稻草人,因为它只会提供一种笨拙的编写程序的方式。而且它甚至无法修复所有问题!但是值得考虑,即使只是为了激发更完整的解决方案。

上面的示例 4 是人为设计的,因为我们可以直接使用 out_specs 而不是主体中的 all_gather

# Example 4 again
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())

# Why didn't we just write it like this?
f4_better = shmap(lambda x: x, P('i'), P('i'))

f4_better 版本没有任何转置问题,因为转置问题是由主体中的集合操作引起的。

类似地,我们可以通过扩展 out_specs 使它们可以表达求和来修复示例 1:

# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# What if we could write an output sum like this?
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i'))  # sum='i' means sum over that axis

# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), P(sum='i'))

因此,在 out_specs 中构建 psum 可以解决示例 1 的转置问题。但是它并没有完全解决示例 3 中被诅咒的恒等转置:

# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())

# How it would transpose with the P-sum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))

这是一个改进,因为程序不会随着我们不断转置而变得越来越大,但我们仍然在进行浪费的通信。

完整解决方案:静态跟踪设备变化的中间值与设备不变的中间值,加上新的原语#

这个解决方案有两个组成部分:

  1. 跟踪值何时保证在特定网格轴上是设备不变的还是设备变化的,以及

  2. psum 分解为两步过程,引入一个新的 pbroadcast 原语,并为 all_gather 及其转置引入新的原语。

从本质上讲,设备不变与设备变化信息的跟踪是类型级别的考虑。但是为了我们第一次实现的方便,我们不需要字面上将信息添加到抽象值或 jaxpr 类型中。在我们开始实现之前,我们将首先使用类型介绍这个想法。

接下来还将讨论如何使用户 API 方便且向后兼容。但是,为了首先介绍这个想法,我们将忽略便利性,而是编写尽可能明确的代码。

在 avals 中跟踪设备不变性(又名带有名称的 avals,复活)#

有时,我们可以仅从静态信息中得知,shmap 主体中某些中间变量的值保证沿网格轴不变,这意味着沿网格轴的函数实例(及其对应的设备)必须都使用相同的值进行计算。我们将此类值称为设备不变值。对于不是设备不变的值,我们将说它们是设备变化的,尽管实际上我们指的是从类型系统的角度来看,它们可能是设备变化的。

为了在类型中编码设备变化,我们将扩展数组的类型语法。我们将编写诸如 x:f32[3,4]{i} 之类的东西,以指示 x 沿网格轴 i(可能)是设备变化的(并且在 shmap 的任何其他网格轴上是设备不变的)。更一般地说,我们将说数组类型语法的语法类似于:

shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...}

我们还将更新类型规则以处理设备变化类型:

  • 对于集合操作之外的一阶原语:

    • 对于多参数的原始操作,操作数的设备方差类型必须相同,且形状也必须相同。例如,mul x:f32[s1]{r1} y:f32[s2][r2] 要求除了 s1 == s2 之外,还需要 r1 == r2

    • 输出的设备方差类型必须与操作数相同。

  • 对于高阶原始操作:

    • 我们只需实例化任何类型变量,包括设备方差类型(并且检查类型是否相等时,会检查它们的设备方差类型是否相等)。

    • (当执行类型推断时,例如对于 cond 的分支,我们会取设备方差类型中轴名称集合的并集。)

  • 对于一阶集合操作:

    • 一个集合操作可以接受设备相关的输入,也可以接受设备无关的输入(沿着与其轴名称参数对应的网格轴);将设备无关的操作数传递给接受设备相关操作数的集合操作,反之亦然,都是错误的。

    • 一个集合操作可以产生设备相关的输出,也可以产生设备无关的输出。

    • 请参阅下表。作为额外的优点,实现此类型检查的任何逻辑都可以取代 shmap 的“静态分析”检查,以确定 shmap 主体函数是否与任何未映射的 out_specs 兼容。

下表总结了集合原始操作的设备方差类型:

名称

设备方差类型

示例

降级为 HLO

转置

psum2

可变 -> 不变

y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i')

AllReduceSum (通信)

pbroadcast

pbroadcast

不变 -> 可变

y:f32[3]{i} = pbroadcast(x:f32[3], 'i')

无操作 (无通信)

psum

all_to_all

可变 -> 可变

y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0) AllToAll (通信)

all_to_all

axis_index

() -> 可变

idx:i32[]{i} = axis_index('i')

ReplicaId 和一些算术运算 (无通信)

n/a

psum_scatter

可变 -> 可变

y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i')

ReduceScatterSum (通信)

all_gather

all_gather

可变 -> 可变

y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i')

AllGather (通信)

psum_scatter

pscatter

不变 -> 可变

y:f32[2]{i} = pscatter(x:f32[16], 'i')

lambda x: x[axis_index('i'), None] (无通信)

all_gather_invariant

all_gather_invariant

可变 -> 不变

y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i')

AllGather (通信)

pscatter

这里有一些令人惊讶的地方!

  • 我们引入了几个新的原始操作,包括:

    • pbroadcast,有趣的是它会降级为无操作

    • all_gather_invariant,它与 all_gather 降级为相同的操作,但具有不同的设备方差类型(本质上,all_gather 融合了 pbroadcast,而 all_gather_invariant 没有)。

    • pscatter,它是 all_gather_invariant 的对偶(转置)。

  • all_gather 具有设备相关的结果。

直观地说,引入 pbroadcast 的原因(除了使类型规则起作用外)是使 psum 可以转置为物理无操作。我们需要 all_gather 具有设备相关结果的原因是,这样我们才能将其转置为 psum_scatter;如果我们将其保留为设备无关的结果,我们可能需要一个下游的 pbroadcast,并且该组合会转置为低效的 psum,然后进行切片 / pscatter。因此,我们有一个 “融合到” all_gather 中的 pbroadcast,从而可以有效地转置为 psum_scatter。我们提供 all_gather_invariant 及其转置 pscatter 主要是为了完整性;用户不太可能需要它(它对应于示例 4 中的情况,使用 out_specs 以不同的方式编写很容易)。

有趣的是,psumpbroadcast 转置对对应于用户在使用 pmap 训练 LLM 时引入的 psum_idrevid_psumrev

此系统如何解决低效转置示例#

再次考虑简化的动机示例:

# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
  w:f32[]{i} = g(x)
  y:f32[]{} = psum(w, 'i')
  return y

有了这些新规则,转置就是:

# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
              in_specs=P(), out_specs=P('i'))

# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1_transpose(ybar: f32[]):
  wbar:f32[]{i} = pbroadcast(ybar, 'i')
  xbar:f32[3,4]{i} = transpose(g)(wbar)
  return xbar

其中,计算 pbroadcast 应用根本不涉及通信或 FLOP;它是一个无操作。请注意,如果我们不断转置,主体的大小不会增长;实际上,t(t(f1)) == f1。效率得以实现!

只要我们在需要时使用 pbroadcast 使类型检查通过,我们也不会搞砸其他示例。

# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))


# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device varying type.

# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P())

直观地说,在示例 1 中,我们现在只有“原始 psum 的一半”,而在示例 2 中,我们得到了“两半”。对于示例 3,我们根本不需要主体中的任何操作。

对于 all_gather 示例,示例 4 需要使用 all_reduce_invariant 才能获得有效的转置(尽管最好使用 out_specs 而不是主体中的集合操作)。

# Example 4 rewritten with explicit all_reduce_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())

# Example 4 with intermediate device variance types annotated
@partial(shmap, P('i'), P())
def f4(x:f32[1]{i}):
  y:f32[8]{} = all_gather_invariant(x, 'i')
  return y

# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
  xbar:f32[1]{i} = pscatter(ybar, 'i')
  return xbar

对于示例 5,使用设备相关的 all_gather 可以按我们期望的方式工作。

# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
  z:f32[8]{i} = all_gather(x, 'i')
  w:f32[8]{i} = z * y
  return w

# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
  zbar:f32[8]{i} = wbar * y
  xbar:f32[1]{i} = psum_scatter(zbar, 'i')
  return xbar

如何使 API 对用户方便(并向后兼容)#

但是,哪个用户想编写 pbroadcast?哪个开发人员想破坏大量现有的用户代码,这些代码涉及未馈入未映射输出的 psum?我不想!

相反,我们可以自动插入 pbroadcast。这有点类似于我们在 jax.numpy 层进行自动等级提升的方式,插入广播以避免二元运算符中的等级不匹配错误。但是它简单得多,因为我们不需要处理形状元组。典型的规则是:每当我们看到一个多参数操作,其操作数的设备方差类型不一致时,就取操作数设备方差类型的轴名称集合的并集,并插入 pbroadcast 以将每个操作数提升到结果设备方差类型。

在需要时自动插入 pbroadcast 可能意味着我们对同一个操作数多次应用相同的 pbroadcast,从而创建公共子表达式。当我们转置时,这些可能会变成 psum 的和,而不是和的 psum。我们将依靠编译器来适当地清理它。如果这是一个问题,那么我们可以向 pbroadcast 插入传递添加一些简单的记忆化。

默认情况下,all_gather 的用户 API 将意味着 all_gather_p(而不是 all_gather_invariant_p),涵盖常见情况,这意味着不需要插入 pbroadcast

我们可以在 shmap 上提供一个选项来禁用 pbroadcast 的自动插入,在这种情况下,将由用户来确保类型正确性。对于一些希望明确 psum 在反向传递中出现位置的用户来说,此显式选项可能很有吸引力。

如何实现该解决方案#

使实现轻量级的关键是我们不会将这些类型添加到 avals 或 jaxprs 中。至少,一开始不会。这可能会很昂贵,因为它需要更新 JAX 的其余部分,例如,avals 和 jaxprs 的所有使用者可能需要处理新类型。我们不会再犯同样的错误了!

相反,我们将这些扩展类型保留为 shmap 内部的元数据,就像当前“用于 out_specs 的复制检查”机制是 shmap 内部的一样。实际上,此解决方案是对现有机制的一个相对较小的扩展:它已经跟踪相同的信息;现在我们只是添加 pbroadcast

我们至少有两种选择可以在哪里执行 pbroadcast 插入:

  1. 在转置之前,在转置规则中,我们有一个要转置的计算的 jaxpr;

  2. 在每个 shmap 主体中,无论是急切执行还是分阶段输出,都像当前“对 out_specs 进行复制检查”的机制一样。前者可能更容易,因为我们只需要处理 jaxpr 的情况,并且只需要线性原语。但我们将首先尝试后者,所以这里的实现是对现有复制检查逻辑的严格修订/扩展。

附录:定义和激励具有未映射输入和输出的映射#

为了具体起见,我们将主要关注 shmap,尽管这些相同的想法也适用于例如 pmap,可能也适用于 xmap

in_specs 的相应条目没有提及该网格轴的名称时,参数/输入沿网格轴是未映射的。从逻辑上讲,这意味着沿该网格轴的每个函数实例都获得相同的参数值。对于调用者来说,每个操作数都根据操作数映射到的网格轴进行切片,而对于操作数未映射到的网格轴则不进行切片。

out_specs 的相应条目没有提及该网格轴的名称时,输出沿网格轴是未映射的。从逻辑上讲,这意味着沿该网格轴的每个函数实例必须返回相同的值。对于调用者来说,shmap 的每个结果都是通过连接沿输出映射到的每个函数实例的返回值形成的,而对于输出未映射到的网格轴,仅使用一个值的副本。

有关未映射输入和输出的示例,请参阅 shmap JEP。为了进行比较,在 vmap 中,未映射的输入/输出通过使用 in_axes / out_axesNone(而不是 int)来指示。

以下是我们喜欢 shmap 的未映射输入和输出的原因:

  • pjit 相同的表达能力。 任何 pjit 可以做的事情,shmap 的逃生舱也应该能够做到。否则我们将有一个缺乏的逃生舱!如果我们在 shmap 中没有未映射的输出,那么我们就无法表达与 pjit 相同的批处理并行损失函数计算。

  • 闭合输入。 闭合输入本质上对应于未映射的输入,并且…

  • 转置下的闭包。 一旦我们有了未映射的输入,很自然就可以转置为未映射的输出。

因此,未映射的输出既是规范的,也是有用的!