Pallas 中用于 TPU 的分布式计算#

在本教程中,我们将介绍在 TPU 上使用 Pallas 进行分布式计算的基础知识。我们将学习 TPU 拓扑结构、使用远程 DMA 原语进行通信,以及使用 shard_map 从 JAX 调用分布式内核。我们还将介绍一些更高级的内核编写技术,例如双缓冲、双向带宽优化和嵌套流水线。作为教学示例,我们将学习如何实现 JAX 中的各种集体原语,例如 lax.ppermutelax.all_gatherlax.psumlax.psum_scatter

事先推荐阅读的一些资料

import jax
from jax import lax
from jax import numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental import shard_map
from jax.experimental.pallas import tpu as pltpu

P = jax.sharding.PartitionSpec

num_devices = jax.local_device_count()
assert num_devices > 1, "Please run this notebook with more than one device."
assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices."
print(f"Running with {num_devices} {jax.devices()[0].device_kind} devices.")
Running with 4 TPU v5 lite devices.

TPU 拓扑结构#

TPU 通常部署在多个设备的 pod 中,这些设备通过高带宽芯片间互连 (ICI) 连接,用于 pod 内的通信,其速度远快于典型的网络连接。例如,TPU v5p 的规格表指出,每个芯片的 ICI 带宽为 4.8Tb/s(作为参考,TPU v5p 还具有 21Tb/s 的本地 HBM 带宽)。ICI 使我们能够实现快速且高性能的分布式内核,这些内核需要在 pod 内进行高带宽通信,并使用数据中心网络进行带宽密集度较低的操作的并行化,例如在批处理维度上的数据并行性。

TPU pod 通常以 ND 环面拓扑结构排列。下图给出了几种不同大小的配置示例。

tpu_topologies

环面可以展平成图表,如下所示。每条边(橙色或黑色)是两个设备之间的双向连接。在讨论设备拓扑结构时,您通常会听到关于环的讨论 — 环面的一个关键特征是,当沿着 pod 的轴截取切片时,例如节点 [(0,1), (1, 1), (2, 1), (3, 1)][(0, 1), (1, 1)],我们有一个设备环。这是一个我们可以用来简化 pod 内通信模式的特性。

tpu_torus

远程直接内存访问 (RDMA) 模型#

TPU 通过一种仅推送模型进行通信,该模型称为远程直接内存访问 (RDMA)。TPU 可以发出复制指令,将本地缓冲区中的数据推送到同一 pod 内的另一个设备上的任何缓冲区,该操作与主程序线程异步执行。但是,TPU 只能读取本地存储的数据。这与更传统的多核编程不同,在多核编程中,可以从共享内存读取值并写入值。

异步远程复制操作#

pltpu.make_async_remote_copy 函数用于创建一个远程 DMA 描述符对象,该对象参数化了“发送”操作和“接收”操作。以下是其签名

 def make_async_remote_copy(
     src_ref: Ref,
     dst_ref: Ref,
     send_sem: Ref[SemaphoreType],
     recv_sem: Ref[SemaphoreType],
     device_id: int | tuple[int, ...],
     device_id_type: DeviceIdType
 ) -> AsyncCopyDescriptor:
  • src_ref 是本地 Ref(在任何内存空间中),其中包含您希望发送到另一个设备上的 dst_ref 的数据。

  • dst_ref 是远程 Ref(在任何内存空间中),数据将复制到目标设备上的该位置。

  • send_sem 是一个 DMA 信号量,用于阻塞直到所有数据都已从 src_ref 发送。

  • recv_sem 是一个 DMA 信号量,用于阻塞直到在 dst_ref 处接收到预期数量的字节。DMA 的发送者将写入接收者的 recv_sem

  • device_id 是要发送到的目标设备的设备 ID。

  • device_id_type 指定 device_id 的格式,可以是 LOGICAL 格式(整数设备 ID),也可以是 MESH 格式(逻辑设备网格中的 ND 元组索引)。默认模式为 MESH。

make_async_remote_copy 返回一个描述符对象,您可以使用该对象的 .start() 方法来启动 DMA,使用 .wait_send() 来阻塞 send_sem,使用 .wait_recv() 来阻塞 recv_sem(或使用 .wait() 来同时阻塞两者)。如果一个设备仅预期发送数据,则仅调用 .start().wait_send() 就足够了,同样,如果一个设备仅接收数据,则仅调用 .wait_recv() 就足够了。如果使用 SPMD 模式(其中所有设备都执行 DMA),则每个设备通常都会调用 .start().wait()

dma_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id)
dma_descriptor.start() # Initiate the DMA (non-blocking).
# ... do other work
dma_descriptor.wait_send() # Block until all data has been sent.
dma_descriptor.wait_recv() # Block until all data has been received.

例如,让我们可视化一个 DMA,其中我们考虑 4 个设备(索引为 0、1、2、3)。我们考虑这样一种方案,其中设备 0 复制到设备 1,设备 2 和 3 相互复制。在实践中,我们可以通过使用 @pl.when 在设备 ID 上分支来创建这种不对称的通信模式。

(1) 每个设备都创建 DMA 描述符。设备 0、2 和 3 调用 .start() 以启动从 src_ref 发起的 DMA。设备 1 跳过 .start() 并且不执行任何操作,例如,通过使用 pl.when

rdma_start

(2) 由于 .start() 是非阻塞的,因此每个设备都可以在 DMA 传输过程中自由地执行其他计算。设备 0、2 和 3 调用 .wait_send() 以等待 send_sem,该信号量会阻塞直到所有数据都已发送。

rdma_send

(3) 最后,设备 1、2 和 3 将调用 .wait_recv() 以等待 recv_sem,直到所有数据都已到达 dst_ref

rdma_recv

上述通信模式可以编写如下

def example_kernel(input_ref, output_ref, send_sem, recv_sem):
    device_id = lax.axis_index('x')
    copy_0_to_1 = pltpu.make_async_remote_copy(
        src_ref=input_ref,
        dst_ref=output_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=1,
    )
    copy_2_to_3 = pltpu.make_async_remote_copy(
        src_ref=input_ref,
        dst_ref=output_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=3,
    )
    copy_3_to_2 = pltpu.make_async_remote_copy(
        src_ref=input_ref,
        dst_ref=output_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=2,
    )
    @pl.when(device_id == 0)
    def _():
      copy_0_to_1.start()
      copy_0_to_1.wait_send()
    @pl.when(device_id == 1)
    def _():
      copy_0_to_1.wait_recv()
    @pl.when(device_id == 2)
    def _():
      copy_2_to_3.start()
      copy_2_to_3.wait_send()
      copy_3_to_2.wait_recv()
    @pl.when(device_id == 3)
    def _():
      copy_3_to_2.start()
      copy_3_to_2.wait_send()
      copy_2_to_3.wait_recv()

DMA 信号量#

send_semrecv_sem 是专门用于 DMA 的特殊类型信号量的实例。在向 pallas_call 指定输入规范时,必须使用 tpu.SemaphoreType.DMA 类型来分配它们。

在内部,DMA 信号量可以被认为是整数值进度跟踪器。在 DMA 启动时,本地设备将开始异步递增 send_sem 的值和接收者的 recv_sem。等待信号量会阻塞直到信号量的值达到发送/接收的总字节数;当达到该值时,等待线程会被释放,并且信号量的值会减少相同的量。这意味着要么所有数据都已发送(对于 send_sem),要么所有数据都已接收(对于 dst_sem)。可以使用 pl.semaphore_read 读取信号量的值,但请注意,该值的底层语义可能会因硬件代系而异(例如,该值可能并不完全表示发送的字节数,尽管在推理信号量的行为时,这是一个有用的心理模型)。

路由#

允许发送者将数据发送到同一 pod 内的任何接收者,即使它们没有共享直接连接(此规则的例外情况是对于 TPU v5e,其中设备只能路由到与其自身偏移量为 2 的幂的设备)。TPU 具有内部路由机制,可以将数据传递到路径上通往目的地的下一个设备。但是,不建议以这种方式进行通信,因为作为内核编写者,您无法控制网络争用。我们在本教程中介绍的示例通过仅将数据传输到相邻设备来最大限度地减少低效的通信。

故障模式#

如果错误地使用远程 DMA,您可能会遇到几种难以调试的故障模式。DMA 使用错误的常见症状是崩溃、挂起或静默数据损坏

  • 如果信号量在程序退出时具有无效的非零值,则 Pallas 将崩溃并退出程序。

  • 如果信号量正在等待,但接收到的字节数不足(即,没有发送者,或者如果发送的数据小于接收设备上 dst_ref 的大小),则程序可能会无限期地挂起,等待永远不会发送的字节。在这种情况下,需要重新启动程序。

  • 如果遇到竞争条件,则如果发生两次同时写入或一次同时读取和写入,则可能会出现静默数据损坏。

上述的一些常见原因包括

  • 如果一个设备调用 .wait_recv(),但没有其他设备向其发送数据,则内核可能会挂起。

  • 如果向设备发送的字节数超过了其预期接收的字节数,则也可能会由于非零信号量状态而崩溃。如果发送的字节数较少,则可能会无限期地挂起。

  • 如果启动了 DMA,但没有等待信号量,则程序可能会由于非零信号量状态而崩溃。

  • 如果两个设备复制到同一目标,则可能会由于竞争条件而遇到不确定的结果,或者由于非零信号量状态而崩溃。

示例:右排列 (lax.ppermute)#

让我们深入研究一个非常基本的示例。我们将实现一个执行右排列的内核,其中每个设备将其数据切片发送到其右邻居。

假设我们有一个包含 512 个元素的数组,我们将其分片为大小为 128 的切片,分布在 4 个设备上。每个设备将其切片传递给下一个设备,输出将包含相同的数据,但切片会旋转 1 次。这与 lax.ppermute 操作相同,其中排列设置为 (n, (n+1) % 4)

为了在分布式模式下调用内核,我们将 pallas_call 包装在 shard_map 转换中。从那里,我们可以像编写普通的单设备 Pallas 内核一样编写内核,只不过我们现在可以访问远程 DMA 指令。JAX 集体原语(例如 lax.axis_index)可用于获取 device_id,通过引用传递给 shard_map 的相同命名轴名称,可以用于计算要复制到的目标设备。

partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)

# Create an input array that shards the last dimension across
# all devices.
input_arr = jax.random.uniform(jax.random.key(0), (8, 128 * num_devices))
input_arr = jax.device_put(input_arr, sharding)


def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem):
  my_id = lax.axis_index('x')
  right_neighbor = lax.rem(my_id + 1, num_devices)
  remote_copy_op = pltpu.make_async_remote_copy(
      src_ref=input_ref,
      dst_ref=output_ref,
      send_sem=send_sem,
      recv_sem=recv_sem,
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy_op.start()
  remote_copy_op.wait()


out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)
grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    # TPUMemorySpace.ANY will (usually) place the tensor in HBM.
    in_specs=[
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ],
    out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    scratch_shapes=(
        # We allocate DMA semaphores in scratch memory.
        [pltpu.SemaphoreType.DMA] * 2
    ),
)
right_permute = pl.pallas_call(
    right_permute_kernel,
    out_shape=out_shape,
    grid_spec=grid_spec,
)
# Wrap the kernel within a shard_map to call.
pallas_result = jax.jit(
    shard_map.shard_map(
        right_permute,
        mesh=mesh,
        in_specs=partition,
        out_specs=partition,
        check_rep=False,
    )
)(input_arr)

# Compare Pallas result to XLA shard_map result.
perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices))

xla_result = jax.jit(
    shard_map.shard_map(
        lambda x: lax.ppermute(x, 'x', perm),
        mesh=mesh, in_specs=partition, out_specs=partition)
)(input_arr)

print('Input = ', input_arr[0, ::128])
print('Pallas Result = ', pallas_result[0, ::128])
print('lax.ppermute Result = ', xla_result[0, ::128])
print(
    'Difference |Pallas - lax.ppermute| = ',
    jnp.mean(jnp.abs(pallas_result - xla_result)),
)
Input =  [0.9858954  0.11763906 0.9955574  0.775211  ]
Pallas Result =  [0.775211   0.9858954  0.11763906 0.9955574 ]
lax.ppermute Result =  [0.775211   0.9858954  0.11763906 0.9955574 ]
Difference |Pallas - lax.ppermute| =  0.0

示例:All-gather (lax.all_gather)#

在下一个示例中,我们将实现 all-gather 集体操作,它在 JAX 中有一个等效的 lax.all_gather。与上面的右移示例仅涉及一对源和目标邻居不同,all-gather 操作需要在所有设备之间进行通信,因此我们必须考虑数据如何在它们之间路由。我们如何实现这一点的具体细节取决于设备拓扑结构,我们假设它是环形的。

环形通信模式#

我们将编写我们的内核,假设拓扑结构是环形的。环形非常适合 TPU,因为沿着环面的任何维度切片都会产生一个环。在编写集体操作时,我们通常只需要一次考虑环面的一维切片,因为环面的不同维度保留用于不同类型的并行性(例如,数据与模型)。

我们将使用的策略是编写一个循环内核,其中在每次迭代中,设备从其左邻居接收分片数组的一个切片,并将先前接收的切片复制到其右邻居。经过 num_devices 次迭代后,每个设备将在其本地 HBM 中拥有整个数组的副本。

all_gather

我们可以重新利用 Pallas 的 grid 参数来实现循环。我们不是像在之前的教程中那样遍历数组的图块,而是将网格设置为 (num_devices,),以表明我们想要遍历设备的数量并使用 pl.program_id 来获取 Pallas 内核内部的循环迭代。以下代码片段演示了如何实现此功能

partition = P('x', None)
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)

# Create an input array that shards the first dimension across
# all devices.
input_arr = jax.random.uniform(jax.random.key(0), (8 * num_devices, 128))
input_arr = jax.device_put(input_arr, sharding)


def all_gather_kernel(input_ref,
                      output_ref,
                      local_copy_sem,
                      send_sem,
                      recv_sems):
  outer_step = pl.program_id(0)
  my_id = lax.axis_index('x')
  right_neighbor = lax.rem(my_id + 1, num_devices)
  copy_slot = my_id - outer_step
  copy_slot = lax.rem(copy_slot + num_devices, num_devices)

  @pl.when(outer_step == 0)
  def _():
    local_copy_op = pltpu.make_async_copy(
      src_ref=input_ref,
      dst_ref=output_ref.at[my_id],
      sem=local_copy_sem,
    )
    local_copy_op.start()
    local_copy_op.wait()

  # Copy to our right neighbor.
  # Note that we will also be receiving data from our left neighbor,
  # but at `copy_slot-1` rather than `copy_slot`! This makes use of the fact
  # that the indices do not need to be symmetric between remote DMAs.
  remote_copy_op = pltpu.make_async_remote_copy(
      src_ref=output_ref.at[copy_slot],
      dst_ref=output_ref.at[copy_slot],
      send_sem=send_sem,
      recv_sem=recv_sems.at[outer_step],
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy_op.start()
  remote_copy_op.wait()

out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32)
grid_spec = pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                # TPUMemorySpace.ANY will (usually) place the tensor in HBM.
                pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
            ],
            out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
            scratch_shapes=(
              # DMA semaphores are allocated in scratch memory.
              # We allocated one semaphore for a local HBM-VMEM copy,
              # and one for the remote send semaphore.
              [pltpu.SemaphoreType.DMA] * 2
              # We additionally allocate one receive semaphore per device.
              # This is to avoid situations where we have multiple
              # DMAs in flight, as we do not want to share a receive
              # semaphore between the DMAs.
              + [pltpu.SemaphoreType.DMA((num_devices-1,))]

            ),
            grid=(num_devices-1,)
        )

all_gather = pl.pallas_call(
      all_gather_kernel,
      out_shape=out_shape,
      grid_spec=grid_spec,
  )

# Wrap the kernel within a shard_map to call.
pallas_result = jax.jit(
      shard_map.shard_map(
          all_gather,
          mesh=mesh,
          in_specs=partition,
          out_specs=partition,
          check_rep=False
      )
)(input_arr)

# Compare Pallas result to XLA shard_map result.
xla_result = jax.jit(
    shard_map.shard_map(
        lambda x: lax.all_gather(x, 'x'),
        mesh=mesh, in_specs=partition, out_specs=partition
    )
)(input_arr)

print('Input: ', input_arr.shape, input_arr[::8, 0])
print('Pallas Result: ', pallas_result.shape, pallas_result[:, 0, 0])
print('lax.all_gather Result: ', xla_result.shape, xla_result[:, 0, 0])
print('Difference |Pallas - lax.all_gather| = ',
      jnp.mean(jnp.abs(pallas_result - xla_result)))
Input:  (32, 128) [0.9858954  0.54248166 0.9547038  0.954962  ]
Pallas Result:  (16, 8, 128) [0.9858954  0.54248166 0.9547038  0.954962   0.9858954  0.54248166
 0.9547038  0.954962   0.9858954  0.54248166 0.9547038  0.954962
 0.9858954  0.54248166 0.9547038  0.954962  ]
lax.all_gather Result:  (16, 8, 128) [0.9858954  0.54248166 0.9547038  0.954962   0.9858954  0.54248166
 0.9547038  0.954962   0.9858954  0.54248166 0.9547038  0.954962
 0.9858954  0.54248166 0.9547038  0.954962  ]
Difference |Pallas - lax.all_gather| =  0.0

这里值得一提的一个细节是使用了多个接收信号量。因为我们只在接收设备上阻塞,所以发送者仍然有可能在接收者完成处理第一个 DMA 之前发送多个正在进行的 DMA(请参阅下一节和 reduce-sum 示例,其中更详细地讨论了竞争条件)。在这种情况下,我们可能会遇到为多个同时发生的 DMA 使用同一个信号量的情况。为避免这种情况,我们分配 num_devices-1 个信号量,因此没有重复使用的风险。虽然这种竞争条件不太可能发生在如此小的内核上,但在较大的内核上,设备更有可能失去同步并可能导致静默失败。

高级技术#

现在我们已经了解了如何使用远程 DMA 操作编写几个基本内核,我们将介绍更多用于同步和编写高效内核的高级技术。

同步:常规信号量和屏障信号量#

我们在基本教程中实现的示例不需要特殊处理同步,因为所有必要的通信都写入不相交的缓冲区。但是,其他操作可能需要更复杂的通信模式,需要额外的同步原语来避免竞争条件。Pallas 提供了两个额外的原语来帮助解决这个问题:常规信号量和屏障信号量。

常规信号量#

常规信号量是在多个设备之间进行同步的标准工具。信号量从根本上来说是计数器 - 任何设备都可以递增它们,之后设备可以阻塞,直到信号量的值达到特定值(然后递减该值)。

可以在常规信号量上使用的三个主要操作是 signal、wait 和 read

def semaphore_signal(
    sem: Ref[SemaphoreType],
    inc: int,
    device_id: int | tuple[int, ...],
    device_id_type: DeviceIdType
) -> None:
  ... # Increments the semaphore `sem` on the target device `device_id` by `inc`.
  
def semaphore_wait(
    semaphore: Ref[SemaphoreType],
    value: int,
) -> None:
  ... # Blocks until the locally allocated copy of `sem` reaches `value`, then decrement by `value` and proceed.
    
def semaphore_read(
    sem: Ref[SemaphoreType],
) -> jax.Array:
  ...  # Returns the current value of `sem` as an `int32[]`.

为了使用常规信号量,可以像分配 DMA 信号量一样分配它们,但是要指定 pltpu.SemaphoreType.REGULAR 而不是 pltpu.SemaphoreType.DMA

信号量必须在 Pallas 程序结束时为零才能成功完成。有两种可能发生错误的情况

  • 如果信号量被过度信号化,程序将以非零 (>0) 信号量结束。在这种情况下,程序将在完成时崩溃。这对于调试很有用,因为非零信号量通常意味着程序内部存在某个错误。

  • 如果信号量被过度等待,程序将挂起在阻塞的 semaphore_wait 调用上,同时等待信号量递增。在这种情况下,需要重新启动设备或程序。

屏障信号量#

屏障信号量是全局分配的信号量,用于在整个程序中同步设备,并确保所有设备都已进入 Pallas 内核。

如果在较大的 XLA 程序的上下文中执行 Pallas 内核,我们需要确保所有通信设备都已进入内核。但是,DMA 和常规信号量都是局部作用域的 - 它们仅被已进入内核的其他设备理解。屏障信号量充当全局理解的信号量,可用于同步,无论设备当前在 XLA 程序中的哪个位置执行。

默认情况下,如果不指定屏障信号量,Pallas 会在程序开始时自动插入一个屏障信号量。但是,编写自己的屏障信号量可能更有效。屏障信号量类似于常规信号量,因为它们是计数器,可以通过 semaphore_signal 递增,并且可以通过 semaphore_wait 递减。它们是通过在内核中调用 get_barrier_semaphore() 创建的。通常,我们在内核开始时使用屏障一次,以与所有正在通信的设备同步。

from jax.experimental.pallas import tpu as pltpu

def example_kernel(...):
  # Use barrier semaphores at the beginning of a kernel.
  # is_start_of_kernel = ...
  # right_neighbor = ...
  # ...
  @pl.when(is_start_of_kernel)
  def _():
    barrier_sem = pltpu.get_barrier_semaphore()
    # Increment the semaphore of your right neighbor.
    pltpu.semaphore_signal(
          barrier_sem,
          device_id=right_neighbor,
          device_id_type=pltpu.DeviceIdType.LOGICAL,
    )
    # Wait until your left neighbor has incremented your semaphore
    pltpu.semaphore_wait(barrier_sem, 1)
  # ...

使用屏障信号量时,必须将 collective_id 编译器参数传递给 pallas_call,以指定正在使用哪个屏障信号量。TPU 有少量固定的可用屏障信号量(通常在 20-30 个左右),因此应该谨慎使用。为了确保正确性,只有共享相同通信模式的内核才应该使用相同的 collective_id。例如,如果两个内核仅与同一网格轴上的邻居同步,则允许它们共享相同的 collective_id。但是,如果两个内核沿不同的轴同步,则它们必须具有不同的 collective_id。否则可能会导致难以调试的竞争条件。

kernel = pl.pallas_call(
      example_kernel,
      ...,
      compiler_params=pltpu.TPUCompilerParams(collective_id=0),
)

双缓冲#

为了避免从另一个设备也在写入的本地 Ref 读取并创建竞争条件,一种有用的技术是“双缓冲”策略,其中我们为每个目标值分配两个 Ref。在每次迭代中,一个 Ref 将被指定为“工作”槽,另一个将被指定为“接收”槽。设备可以自由使用工作槽进行计算,但只会将数据复制到其邻居的接收槽。工作槽和接收槽在每次迭代时交替,因此一旦复制完成,旧的接收槽就会成为新的工作槽,反之亦然。正确使用此方案,永远不会从同一缓冲区读取和写入数据。

以下代码框架演示了如何使用双缓冲。我们在变量 iteration 中保留一个运行迭代计数器,并且 working_slotreceiving_slot 在每次迭代时在 0 和 1 之间交替。dst_ref 被分配为双缓冲区,大小为 [2, ...]。在每次迭代中,我们使用 dst_ref.at[working_slot, ...] 从工作槽读取,并使用该值执行计算。同时,我们复制到邻居的 dst_ref.at[receiving_slot],以避免覆盖他们的 working_slot 值。通过以这种方式构建我们的通信,可以在最小化竞争条件风险的同时,将远程 DMA 的通信延迟与本地计算重叠。

def kernel(...):
  # ...
  iteration = pl.program_id(0)
  working_slot = lax.rem(iteration, 2)
  receiving_slot = 1 - working_slot
  # ...

  local_copy_op = pltpu.make_async_copy(
    src_ref=dst_ref.at[working_slot, ...],
    dst_ref=local_scratch_ref,
    sem=local_copy_sem,
  )
  local_copy_op.start()
  remote_copy_op = pltpu.make_async_remote_copy(
    src_ref=src_ref,
    dst_ref=dst_ref.at[receiving_slot, ...],
    send_sem=send_sem,
    recv_sem=recv_sem,
    device_id=target_device,
    device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy_op.start()
  
  local_copy_op.wait()
  # ... do work on local_scratch while waiting for async_copy_op to finish.
  remote_copy_op.wait()

在同步方面,如果所有设备都在相同的迭代中执行,则双缓冲构造会起作用。如果发送者设法比接收者提前一个迭代,则其 working_slotreceiving_slot 索引将与接收者相比翻转,这意味着它可能在接收者从中读取的同一时间写入 working_slot。为了避免这种情况,可能需要使用信号量将发送者与接收者同步,或添加额外的缓冲槽(“三倍”、“四倍”或 N 缓冲),以允许额外的提前运行,但代价是占用更多内存。在我们之前的 all_gather 示例中,请注意,内核包含一个带有 N 个槽的接收缓冲区,这完全避免了竞争条件。在我们的下一个内核中,我们将改为通过一个使用带有显式同步的双缓冲的示例。

示例:全规约求和 (lax.psum)#

我们现在将使用双缓冲和信号量进行同步来实现全规约求和内核。对于熟悉 JAX 中集体操作的人来说,等效操作是 lax.psum。全规约是一种标准集体操作,其目标是沿数组的轴进行规约,但数组会跨多个设备分片。

reduce_sum_1

在上面的示例中,我们有一个数组 [5, 2, 1, 3] 跨 4 个设备分片。全规约求和操作将对所有值求和,并将结果复制到每个设备上,从而产生在所有 4 个设备上分片的结果 [11, 11, 11, 11]。

全规约的朴素实现是将所有必需的值收集到每个设备上,然后进行规约。但是,我们可以通过将通信与计算交错来提高此实现的性能。交错的单向全规约可以可视化如下。在每次迭代中,我们从左邻居接收一个输入值,同时将输入传递给下一个邻居,同时用本地累加器递增它。在 N-1 次迭代后,每个设备都将在其内存中拥有完整总和的副本。

reduce_sum_2

将它们放在一起#

以下内核演示了如何将这些原则组合成一个功能内核。

序言(当 outer_step==0 时执行)首先与两个邻居启动一个屏障,以确保它们也已进入内核。它还处理所有 Ref 的初始化,并处理到右邻居“工作”槽的第一个远程复制。

主体会假定一个值已经从上次迭代或序言复制到我们的本地工作槽中。一个复杂因素是我们的目标缓冲区位于 HBM 中,但是我们需要在执行算术运算之前将值加载到 VMEM 中。因此,我们同时将工作槽值复制到我们的 VMEM (receive_scratch) 中,并将值传递给右邻居的接收槽。一旦该值已复制到我们的 VMEM 中,我们就可以将其累加到我们的结果(包含在 o_ref 中)。

如果一个设备比其右邻居提前运行一个循环,则可能会发生微妙的竞争条件。在这种情况下,它可能会在接收方正在从中读取的同时,复制到接收方的 working_slot 中。为了避免这种情况,每个设备将在复制到右邻居的 dst_ref 之前,阻塞一个 REGULAR 信号量,直到它发出信号表明它已完成从其 working_slot 的读取。对于像本例这样的小内核,很少会触发此竞争条件,但如果使用 pltpu.delay 指令人为挂起设备,则可以显式触发它。

请注意,这不是一个最佳或完全通用的内核,因为块大小必须完全适合 VMEM,并且我们可以更好地交错通信和累加。我们将在后面的章节中讨论这些优化。

partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)

input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices))
input_arr = jax.device_put(input_arr, sharding)


def all_reduce_kernel(
    x_ref,
    o_ref,
    hbm_scratch,
    copy_sem,
    remote_recv_sem,
    remote_send_sem,
    capacity_sem,
    receive_scratch,
):
  outer_step = pl.program_id(0)
  working_slot = lax.rem(outer_step, 2)
  receiving_slot = 1 - working_slot

  my_id = lax.axis_index('x')
  right_neighbor = lax.rem(my_id + 1, num_devices)
  left_neighbor = lax.rem(my_id - 1 + num_devices, num_devices)

  @pl.when(outer_step == 0)
  def _():
    # Barrier with both neighbors at the start, since we will be
    # communicating with both.
    barrier_sem = pltpu.get_barrier_semaphore()
    pltpu.semaphore_signal(
        barrier_sem,
        inc=1,
        device_id=(left_neighbor,),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    pltpu.semaphore_signal(
        barrier_sem,
        inc=1,
        device_id=(right_neighbor,),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    pltpu.semaphore_wait(barrier_sem, 2)

    # Initialize o_ref, acc_scratch, and hbm_scratch.
    o_ref[...] = jnp.zeros_like(o_ref)
    receive_scratch[...] = jnp.zeros_like(receive_scratch)
    initial_copy = pltpu.make_async_remote_copy(
        src_ref=x_ref,
        dst_ref=hbm_scratch.at[working_slot],
        send_sem=remote_send_sem,
        recv_sem=remote_recv_sem,
        device_id=(right_neighbor,),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    initial_copy.start()
    initial_copy.wait()

  # Signal to our left neighbor that we are ready to receive.
  # Without this signal, our left neighbor can be >=1 iteration ahead,
  # meaning it could write into our working slot.
  pltpu.semaphore_signal(
      capacity_sem,
      inc=1,
      device_id=(left_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  # Copy the partial result our left neighbor sent to us into VMEM for
  # computation.
  local_copy = pltpu.make_async_copy(
      src_ref=hbm_scratch.at[working_slot],
      dst_ref=receive_scratch,
      sem=copy_sem,
  )
  local_copy.start()

  # Block until our right neighbor is ready to receive.
  pltpu.semaphore_wait(capacity_sem, 1)
  # Pass the value to our right neighbor.
  remote_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[working_slot],
      dst_ref=hbm_scratch.at[receiving_slot],
      send_sem=remote_send_sem,
      recv_sem=remote_recv_sem,
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy.start()
  # Finish local copy and accumulate while remote_copy is happening.
  local_copy.wait()
  o_ref[...] += receive_scratch[...]
  # Block until remote copy finishes.
  remote_copy.wait()


out_shape = (
    jax.ShapeDtypeStruct((8, 128), jnp.float32),
    # We allocate the double-buffer as a Pallas output so that it is
    # resident in HBM.
    jax.ShapeDtypeStruct((2, 8, 128), jnp.float32),  # hbm_scratch
)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    in_specs=[
        # Our input lives in VMEM
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
    ],
    out_specs=[
        # Our output lives in VMEM
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        # Our double-buffer lives in HBM
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ],
    grid=(num_devices,),
    scratch_shapes=(
        [pltpu.SemaphoreType.DMA] * 3
        + [pltpu.SemaphoreType.REGULAR]  # capacity_sem
        + [pltpu.VMEM((8, 128), jnp.float32)]  # receive_scratch
    ),
)

kernel = pl.pallas_call(
    all_reduce_kernel,
    out_shape=out_shape,
    grid_spec=grid_spec,
    compiler_params=pltpu.TPUCompilerParams(collective_id=0),
)

pallas_result = jax.jit(
    shard_map.shard_map(
        kernel,
        mesh=mesh,
        in_specs=partition,
        out_specs=partition,
        check_rep=False,
    )
)(input_arr)
pallas_result = jax.block_until_ready(pallas_result)[0]


def lax_sum(x):
  return lax.psum(x, 'x')


xla_result = jax.jit(
    shard_map.shard_map(
        lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x')
    )
)(input_arr)

print('Input = ', input_arr[0, ::128])
print('Pallas result = ', pallas_result[0, ::128])
print('lax.psum result = ', xla_result[0, ::128])
difference = jnp.mean(jnp.abs(pallas_result - xla_result))
print('Difference |Pallas - lax.psum| = ', difference)
Input =  [0.9858954  0.11763906 0.9955574  0.775211  ]
Pallas result =  [2.8743029 2.8743029 2.8743029 2.8743029]
lax.psum result =  [2.8743029 2.8743029 2.8743029 2.8743029]
Difference |Pallas - lax.psum| =  1.4959369e-08

超前运行和竞争条件#

作为一般经验法则,为了最大限度地提高性能,我们希望允许设备尽可能在不同步的情况下超前于其他设备运行,而不会牺牲程序的正确性。虽然我们可以在每次迭代开始时强制所有设备进行屏障,但这会将程序的性能瓶颈限制在每个循环中最慢的设备上。通过放松同步并允许适量的超前运行,我们可以更好地适应迭代和设备之间的延迟差异,因为在一次迭代中速度较慢的设备可以在下一次迭代中赶上。

在我们之前编写的全规约内核中,我们允许设备超前运行,但与邻居相比,少于一次迭代(但是,非相邻设备可能相差 1 次以上的迭代)。要了解为什么信号量同步是必要的,请考虑一个设备(例如设备 2)挂起并落后于其他设备的情况。RDMA 没有“握手” - 只有接收方在等待数据到达时会被阻塞。因此,每个设备最多可以超前运行一次迭代,然后才会被阻塞,等待下一个 RDMA 到达。如果我们有 N 个设备,这意味着最后一个设备最多可以比第一个设备提前 N 次迭代。

race_condition

如果在另一个方向(强制发送者阻塞)不添加同步,则设备 1 可能会比设备 2 最多提前 N 次迭代(N = num_devices),发送多次写入并覆盖过程中的值。为了解决我们之前编写的 all_reduce 内核中的这个问题,我们实现了一个“握手”协议,其中接收方会向发送方发出信号,表示它已准备好接收,然后发送方才开始发出下一个 RDMA。

双向通信#

在我们之前的内核中,我们以从左到右的环形单向方式进行通信。但是,由于 ICI 连接是双向的,因此我们实际上浪费了一半的总带宽,而没有从右到左发送值。在下一个内核中,我们将演示一个在两个方向上进行通信以最大化 ICI 带宽的示例。

示例:双向分散规约 (lax.psum_scatter)#

分散规约操作是全规约后接分散的组合。或者,全规约是分散规约后接全收集的组合。

下图描述了此操作的语义。我们假设每个设备都从部分和的集合开始(用字母 + 数字表示,例如 A0)。目标是沿一个轴(数字)进行规约,同时沿另一个轴(字母)进行分片。

reduce_scatter_1

为了实现双向通信策略,我们将每个输入块切成两半,并为每一半指定一个方向。每个块的上半部分将从右到左传递,下半部分将从左到右传递。与我们之前的全规约和全收集内核的通信模式的第二个偏差是,我们还将传递累加器或部分和,并将输入保留在每个本地设备上。这与之前传递输入但将累加器保留在设备本地的示例相反。传递累加器更适合此问题,因为与全规约相反,输入中的大多数数据不是将本地存储在设备上的输出的一部分。(例如,上图中的 B0C0D0 将不会存储在末尾保持 A 的设备上)。

下图说明了这种通信模式,其中彩色框表示累加器(不是输入!)。最初,累加器只是包含在输入中的值。在该算法的每次迭代中,我们将从每个方向的邻居处接收到部分和。然后,我们计算输入的正确切片以累加到部分缓冲区中,然后将新的部分和传递给下一个邻居。在 N 次迭代后,累加器将通过每个设备,这意味着它最终将保持全部总和。

reduce_scatter_2

就内核的构建而言,我们在 Pallas 网格中引入了一个额外的 phase 维度,该维度表示我们当前正在计算哪个累加器(左或右)。我们让 phase=0 表示累加器向左移动,让 phase=1 表示累加器向右移动。然后,我们对两个阶段进行流水线处理,这样在计算一个阶段的结果时,我们会在相反的方向传输先前计算的值,以便为下一阶段做好准备。例如,当我们在 phase=0(左)时,我们首先启动一个 DMA,将我们在上一次迭代中计算的结果传输到我们的右邻居(右 DMA)。然后,我们将累加到左缓冲区并将结果保存到 HBM。然后,我们等待右 DMA 完成,以便为 phase=1(右)做好准备。

partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)

# We need a block size of (16, 128) to ensure that a half-slice is at least
# of size (8, 128), which is the size of a VREG. This makes tiling easier
# for the compiler.
block_size = (16, 128)
input_arr = jax.random.uniform(
    jax.random.key(0),
    shape=(block_size[0] * num_devices, block_size[1] * num_devices),
)
input_arr = jax.device_put(input_arr, sharding)

LEFT = 0
RIGHT = 1


def mod(x, n):
  return lax.rem(x + n, n)


def signal(left_or_right, semaphore):
  my_id = lax.axis_index('x')
  if left_or_right == LEFT:
    neighbor = mod(my_id - 1, num_devices)
  else:
    neighbor = mod(my_id + 1, num_devices)
  pltpu.semaphore_signal(
      semaphore,
      inc=1,
      device_id=(neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )


def reduce_scatter_kernel(
    x_ref,
    o_ref,
    hbm_scratch,
    local_copy_sem,
    left_recv_sem,
    left_send_sem,
    right_recv_sem,
    right_send_sem,
    left_capacity_sem,
    right_capacity_sem,
    accum_scratch,
):
  outer_step = pl.program_id(0)
  phase = pl.program_id(1)
  is_start = jnp.logical_and(outer_step == 0, phase == 0)
  last_iteration = outer_step == pl.num_programs(0) - 1

  working_slot = lax.rem(outer_step, 2)
  receiving_slot = 1 - working_slot
  my_id = lax.axis_index('x')
  right_neighbor = mod(my_id + 1, num_devices)
  left_neighbor = mod(my_id - 1, num_devices)

  left_copy_device = mod(my_id + outer_step + 1, num_devices)
  right_copy_device = mod(my_id - outer_step - 1, num_devices)
  # Slices can be specified using pl.ds(start, size)
  left_copy_slice = pl.ds(0, block_size[0] // 2)
  right_copy_slice = pl.ds(block_size[0] // 2, block_size[0] // 2)
  current_phase_slice = pl.ds(phase * (block_size[0] // 2), block_size[0] // 2)

  initial_left_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, left_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(left_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  initial_right_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  left_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[working_slot, left_copy_slice],
      dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(left_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  right_copy = pltpu.make_async_remote_copy(
      # Note: Right copy is flipped with regards to slots since we are copying
      # to the next outer_step iteration.
      src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  # --- Prologue ---
  @pl.when(is_start)
  def _():
    # Barrier with both neighbors at the start, since we will be
    # communicating with both.
    barrier_sem = pltpu.get_barrier_semaphore()
    pltpu.semaphore_signal(
        barrier_sem,
        inc=1,
        device_id=(left_neighbor,),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    pltpu.semaphore_signal(
        barrier_sem,
        inc=1,
        device_id=(right_neighbor,),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    pltpu.semaphore_wait(barrier_sem, 2)

    # Initialize o_ref, acc_scratch, and hbm_scratch with initial copies.
    o_ref[...] = jnp.zeros_like(o_ref[...])
    accum_scratch[...] = jnp.zeros_like(accum_scratch[...])

    initial_left_copy.start()
    initial_left_copy.wait()
    initial_right_copy.start()

    # We tell our left neighbor that it is allowed to send to the right.
    # (and vice versa for right neighbor)
    signal(LEFT, right_capacity_sem)
    signal(RIGHT, left_capacity_sem)

  # --- Body ---
  # At the beginning of our kernel body, we start a DMA which copies
  # the result we computed in the previous phase to our neighbor.
  # This allows us to overlap the communication of sending our previous phase
  # with the computation for the current phase.
  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      # We block here until our right neighbor tells use we can send to
      # the right.
      pltpu.semaphore_wait(right_capacity_sem, 1)
      right_copy.start()

    @pl.when(phase == RIGHT)
    def _():
      # We block here until our left neighbor tells use we can send to
      # the left.
      pltpu.semaphore_wait(left_capacity_sem, 1)
      left_copy.start()

  local_copy = pltpu.make_async_copy(
      src_ref=hbm_scratch.at[working_slot, current_phase_slice],
      dst_ref=accum_scratch,
      sem=local_copy_sem,
  )
  local_copy.start()
  local_copy.wait()

  @pl.when(~last_iteration)
  def _():
    @pl.when(phase == LEFT)
    def _():
      accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]

    @pl.when(phase == RIGHT)
    def _():
      accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]

  local_copy = pltpu.make_async_copy(
      src_ref=accum_scratch,
      dst_ref=hbm_scratch.at[working_slot, current_phase_slice],
      sem=local_copy_sem,
  )
  local_copy.start()
  local_copy.wait()

  @pl.when(is_start)
  def _():
    initial_right_copy.wait()

  # At the end of our kernel body, we wait on the DMA of the previous phase
  # to make sure the results are ready for the next phase.
  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      right_copy.wait()
      signal(LEFT, right_capacity_sem)

    @pl.when(phase == RIGHT)
    def _():
      left_copy.wait()
      signal(RIGHT, left_capacity_sem)

  # --- Epilogue ---
  # Store result on last iteration.
  @pl.when(last_iteration)
  def _():
    # Clean up semaphores so that they exit with a value of 0.
    @pl.when(phase == LEFT)
    def _():
      o_ref[left_copy_slice, ...] = accum_scratch[...]
      pltpu.semaphore_wait(right_capacity_sem, 1)

    @pl.when(phase == RIGHT)
    def _():
      o_ref[right_copy_slice, ...] = accum_scratch[...]
      pltpu.semaphore_wait(left_capacity_sem, 1)


out_shape = (
    jax.ShapeDtypeStruct((block_size[0], block_size[1]), jnp.float32),  # output
    # Shape: [working/recv, block[0], block[1]]
    jax.ShapeDtypeStruct(
        (2, block_size[0], block_size[1]), jnp.float32
    ),  # hbm_scratch
)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    in_specs=[
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
    ],
    out_specs=[
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ],
    grid=(num_devices, 2),
    scratch_shapes=(
        [pltpu.SemaphoreType.DMA] * 5
        + [pltpu.SemaphoreType.REGULAR] * 2  # Capacity semaphores
        + [
            pltpu.VMEM((block_size[0] // 2, block_size[1]), jnp.float32)
        ]  # accum_scratch
    ),
)


def pallas_reduce_scatter(input_arr):
  input_arr = input_arr.reshape(num_devices, block_size[0], block_size[1])
  return pl.pallas_call(
      reduce_scatter_kernel,
      out_shape=out_shape,
      grid_spec=grid_spec,
      compiler_params=pltpu.TPUCompilerParams(collective_id=0),
  )(input_arr)[0]


pallas_result = jax.jit(
    shard_map.shard_map(
        pallas_reduce_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
        check_rep=False,
    )
)(input_arr)

pallas_result = jax.block_until_ready(pallas_result)
# Compare our result to XLA.
def lax_reduce_sum_scatter(x):
  x = x.reshape(num_devices, block_size[0], block_size[1])
  return lax.psum_scatter(x, 'x')


xla_result = jax.jit(
    shard_map.shard_map(
        lax_reduce_sum_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
    )
)(input_arr)

print('Input:', input_arr.shape, input_arr[::4, 0])
print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])
print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0])
print(
    'Difference |Pallas - lax.psum_scatter|:',
    jnp.max(jnp.abs(pallas_result - xla_result)),
)
Input: (64, 512) [0.78051674 0.3524047  0.59993696 0.9714314  0.24692321 0.01347649
 0.01857424 0.24841607 0.86097646 0.8261659  0.9753758  0.6902338
 0.4431417  0.963323   0.3158517  0.535548  ]
Pallas Result: (64, 128) [1.3593563 1.6274805 1.0979297 3.082869  1.4194957 1.4163033 1.2401303
 1.1892898 2.6545286 2.221559  2.7995253 2.08431   2.2509837 3.0726733
 2.4662397 1.9542246]
lax.psum_scatter Result: (64, 128) [1.3593563 1.6274805 1.0979297 3.082869  1.4194957 1.4163033 1.2401303
 1.1892898 2.6545286 2.221559  2.7995253 2.08431   2.2509837 3.0726733
 2.4662397 1.9542246]
Difference |Pallas - lax.psum_scatter|: 2.3841858e-07

嵌套的远程和本地 DMA 流水线#

我们编写的之前的全规约和分散规约内核的一个限制是,我们通过远程 DMA 复制的块必须足够小,以适合我们用于累加的工作 VMEM。对于某些内核,使用更大的块大小来更好地利用 TPU 可能是有利的。例如,矩阵乘法需要 \(O(N^3)\) 数量级的计算操作,但只有 \(O(N^2)\) 内存传输。因此,我们希望在设备之间传输的每个工作块都足够大,这样操作就会受到计算的限制,并且我们可以使用流水线隐藏通信成本。作为参考,TPU 的 VMEM(对于 v4/v5 代)通常在 10-100MB 左右,而 HBM 的范围为 10-100GB。

为了解决这个问题,我们需要能够编写一个“内部内核”,该内核在处理设备之间更大的 HBM-HBM 传输流水线的“外部内核”内部处理本地 HBM-VMEM 流水线。Pallas 提供了一个 API,用于使用 emit_pipeline 函数构建嵌套流水线。emit_pipeline 的基本调用签名遵循标准 pallas_call 的签名,方法是为输入和输出指定 gridBlockSpec

def emit_pipeline(
    kernel: Callable,
    grid: tuple[int],
    in_specs: PyTree[BlockSpec] = None,
    out_specs: PyTree[BlockSpec] = None,
    should_accumulate_out: bool = False,
    dimension_semantics: tuple[GridDimensionSemantics] = None,
) -> Callable:
  ... # Returns a custom pipeline given an inner kernel and BlockSpecs.

实际上,可以将 pallas_call 本身视为只是 emit_pipeline 的包装器。因为我们的外部内核只涉及远程 HBM-HBM 传输,所以我们没有使用 pallas_call 为 HBM-VMEM 传输提供的任何内置流水线。以下代码骨架演示了使用此模式的典型程序结构是什么样的


def outer_kernel(...):
  # ... do work to pipeline remote HBM-HBM transfers (outer kernel)

  def inner_kernel(...):
    # ... do work (inner kernel)
  pltpu.emit_pipeline(
          inner_kernel,
          grid=inner_grid,
          in_specs=...,
          out_specs=...,
  )(inner_kernel_args)
  # ... do more work (outer kernel)

pl.pallas_call(
  outer_kernel,
  grid=outer_grid,
  in_specs=...
  out_specs=...
  scratch=inner_kernel_allocs
)

示例:带有大型 HBM 块的分散规约#

在下一个示例中,我们将修改之前的分散规约示例以利用嵌套的内部流水线。请注意,reduce_scatter 的通信和计算成本都与输入的大小呈线性关系,因此我们不一定希望看到该操作在较大的块大小时变为计算限制。此示例仅用于演示如何使用流水线发射器。

我们将增大外部内核的块大小,使其不适合放置在 VMEM 中,并将所有输入和输出都分配在 HBM 中(memory_space=TPUMemorySpace.Any)。与我们之前的内核相比,唯一的主要变化是内核的主体部分,即进行累加的地方。我们不再手动从 HBM 复制到 VMEM,进行累加,然后再复制回 HBM,而是使用 emit_pipeline 来处理内存传输。累加在一个内部内核中进行,该内核具有更小、更适合 VMEM 的块大小。

在我们之前的内核中,我们有以下内核主体将数据从 HBM 复制到 VMEM 累加器,进行递增,然后将结果复制回 HBM

local_copy = pltpu.make_async_copy(
    src_ref=hbm_scratch.at[working_slot, current_phase_slice],
    dst_ref=accum_scratch,
    sem=local_copy_sem,
)
local_copy.start()
local_copy.wait()
@pl.when(~last_iteration)
def _():
  @pl.when(phase == LEFT)
  def _():
    accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]
  @pl.when(phase == RIGHT)
  def _():
    accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]
local_copy = pltpu.make_async_copy(
    src_ref=accum_scratch,
    dst_ref=hbm_scratch.at[working_slot, current_phase_slice],
    sem=local_copy_sem,
)
local_copy.start()
local_copy.wait()

我们新的内核用以下 emit_pipeline 调用替换了它

def inner_kernel(input_ref, accum_ref):
  accum_ref[...] = input_ref[...]
accum_pipeline = pltpu.emit_pipeline(inner_kernel,
                                     in_specs=[inner_block_spec],
                                     out_specs=inner_block_spec,
                                     should_accumulate_out=True,
                                     grid=inner_grid)
@pl.when(~last_iteration)
def _():
  @pl.when(phase == LEFT)
  def _():
    accum_pipeline(x_ref.at[left_copy_device, left_copy_slice],
                   hbm_scratch.at[working_slot, left_copy_slice],
    )
  @pl.when(phase == RIGHT)
  def _():
    accum_pipeline(x_ref.at[right_copy_device, right_copy_slice],
                   hbm_scratch.at[working_slot, right_copy_slice],
    )

完整的内核如下

partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)

# We pick a large outer kernel block size that we do not want to place
# in VMEM. For pedagogical purposes we use (4096, 4096), although in
# principle this can be much larger.
outer_block_size = (4096, 4096)
# We pick a smaller VMEM block size for the inner kernel.
inner_block_size = (128, 128)
input_arr = jax.random.uniform(
    jax.random.key(0),
    shape=(
        outer_block_size[0] * num_devices,
        outer_block_size[1] * num_devices,
    ),
)
input_arr = jax.device_put(input_arr, sharding)


inner_grid = (
    outer_block_size[0] // inner_block_size[0] // 2,
    outer_block_size[1] // inner_block_size[1],
)
inner_block_spec = pl.BlockSpec(
    index_map=lambda i, j: (i, j),
    block_shape=inner_block_size,
    memory_space=pltpu.TPUMemorySpace.ANY,
)


def reduce_scatter_kernel(
    x_ref,
    o_ref,
    hbm_scratch,
    left_recv_sem,
    left_send_sem,
    copy_sem,
    right_recv_sem,
    right_send_sem,
    left_capacity_sem,
    right_capacity_sem,
):
  outer_step = pl.program_id(0)
  phase = pl.program_id(1)
  is_start = jnp.logical_and(outer_step == 0, phase == 0)
  last_iteration = outer_step == pl.num_programs(0) - 1

  working_slot = lax.rem(outer_step, 2)
  receiving_slot = 1 - working_slot
  my_id = lax.axis_index('x')
  right_neighbor = mod(my_id + 1, num_devices)
  left_neighbor = mod(my_id - 1, num_devices)

  left_copy_device = mod(my_id + outer_step + 1, num_devices)
  right_copy_device = mod(my_id - outer_step - 1, num_devices)
  left_copy_slice = pl.ds(0, outer_block_size[0] // 2)
  right_copy_slice = pl.ds(outer_block_size[0] // 2, outer_block_size[0] // 2)
  current_phase_slice = pl.ds(
      phase * (outer_block_size[0] // 2), outer_block_size[0] // 2
  )

  initial_left_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, left_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(left_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  initial_right_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  left_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[working_slot, left_copy_slice],
      dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(left_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  right_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(right_neighbor,),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  # --- Prologue ---
  @pl.when(is_start)
  def _():
    # Barrier with both neighbors at the start, since we will be
    # communicating with both.
    barrier_sem = pltpu.get_barrier_semaphore()
    pltpu.semaphore_signal(
        barrier_sem,
        inc=1,
        device_id=(left_neighbor,),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    pltpu.semaphore_signal(
        barrier_sem,
        inc=1,
        device_id=(right_neighbor,),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    pltpu.semaphore_wait(barrier_sem, 2)

    initial_left_copy.start()
    initial_left_copy.wait()
    initial_right_copy.start()

    # We tell our left neighbor that it is allowed to send to the right.
    # (and vice versa for right neighbor)
    signal(LEFT, right_capacity_sem)
    signal(RIGHT, left_capacity_sem)

  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      # We block here until our right neighbor tells use we can send to
      # the right.
      pltpu.semaphore_wait(right_capacity_sem, 1)
      right_copy.start()

    @pl.when(phase == RIGHT)
    def _():
      # We block here until our left neighbor tells use we can send to
      # the left.
      pltpu.semaphore_wait(left_capacity_sem, 1)
      left_copy.start()

  # --- Body ---
  def inner_kernel(input_ref, accum_ref):
    # We do not explicitly use += because we set should_accumulate_out=True.
    accum_ref[...] = input_ref[...]

  accum_pipeline = pltpu.emit_pipeline(
      inner_kernel,
      in_specs=[inner_block_spec],
      out_specs=inner_block_spec,
      should_accumulate_out=True,
      grid=inner_grid,
  )

  @pl.when(~last_iteration)
  def _():
    @pl.when(phase == LEFT)
    def _():
      accum_pipeline(
          x_ref.at[left_copy_device, left_copy_slice],
          hbm_scratch.at[working_slot, left_copy_slice],
      )

    @pl.when(phase == RIGHT)
    def _():
      accum_pipeline(
          x_ref.at[right_copy_device, right_copy_slice],
          hbm_scratch.at[working_slot, right_copy_slice],
      )

  # --- Epilogue ---
  @pl.when(is_start)
  def _():
    initial_right_copy.wait()

  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      right_copy.wait()
      signal(LEFT, right_capacity_sem)

    @pl.when(phase == RIGHT)
    def _():
      left_copy.wait()
      signal(RIGHT, left_capacity_sem)

  # Store result on last iteration.
  @pl.when(last_iteration)
  def _():
    output_copy = pltpu.make_async_copy(
        src_ref=hbm_scratch.at[working_slot, current_phase_slice],
        dst_ref=o_ref.at[current_phase_slice],
        sem=copy_sem,
    )
    output_copy.start()
    output_copy.wait()

    # Clean up semaphores so that they exit with a value of 0.
    @pl.when(phase == LEFT)
    def _():
      pltpu.semaphore_wait(right_capacity_sem, 1)

    @pl.when(phase == RIGHT)
    def _():
      pltpu.semaphore_wait(left_capacity_sem, 1)


out_shape = (
    jax.ShapeDtypeStruct(
        (outer_block_size[0], outer_block_size[1]), jnp.float32
    ),
    # Shape: [working/recv, block[0], block[1]]
    jax.ShapeDtypeStruct(
        (2, outer_block_size[0], outer_block_size[1]), jnp.float32
    ),  # hbm_scratch
)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    in_specs=[
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ],
    out_specs=[
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ],
    grid=(num_devices, 2),
    scratch_shapes=(
        [pltpu.SemaphoreType.DMA] * 5
        + [pltpu.SemaphoreType.REGULAR] * 2  # Capacity semaphores
    ),
)


def pallas_reduce_scatter(input_arr):
  input_arr = input_arr.reshape(
      num_devices, outer_block_size[0], outer_block_size[1]
  )
  return pl.pallas_call(
      reduce_scatter_kernel,
      out_shape=out_shape,
      grid_spec=grid_spec,
      compiler_params=pltpu.TPUCompilerParams(collective_id=0),
  )(input_arr)[0]


pallas_result = jax.jit(
    shard_map.shard_map(
        pallas_reduce_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
        check_rep=False,
    )
)(input_arr)

pallas_result = jax.block_until_ready(pallas_result)
# Now we compare our result to XLA.
def lax_reduce_sum_scatter(x):
  x = x.reshape(num_devices, outer_block_size[0], outer_block_size[1])
  return lax.psum_scatter(x, 'x')


xla_result = jax.jit(
    shard_map.shard_map(
        lax_reduce_sum_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
    )
)(input_arr)

print('Input:', input_arr.shape, input_arr[::4, 0])
print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])
print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0])
print(
    'Difference |Pallas - lax.psum_scatter|:',
    jnp.max(jnp.abs(pallas_result - xla_result)),
)
Input: (16384, 16384) [0.74162567 0.0242182  0.27751946 ... 0.05213022 0.36088037 0.04494429]
Pallas Result: (16384, 4096) [2.0648427 1.674587  1.9148926 ... 1.3371865 1.3296283 1.2887063]
lax.psum_scatter Result: (16384, 4096) [2.0648427 1.674587  1.9148926 ... 1.3371865 1.3296283 1.2887063]
Difference |Pallas - lax.psum_scatter|: 2.3841858e-07

最终注意事项#

Megacore#

某些 TPU 在 Megacore 配置中包含多个内核。在这种配置中,我们的一般建议是从单个内核启动 DMA,并且仅执行 HBM-HBM 传输。为此,将其中一个网格轴设置为内核数量(可以通过 jax.devices()[0].num_cores 获取),并将 dimension_semantics 设置为 "parallel"。然后,您可以使用 core_index = pl.program_id(axis) 来获取沿该轴的内核索引,并使用 @pl.when(core_index==i) 来执行特定于该内核的代码。

与 XLA 的交互#

在本教程中,我们介绍了一些内核示例,它们复制了 JAX 中集体操作的功能,例如 lax.all_gatherlax.psumlax.psum_scatter。需要注意的一个重要警告是,Pallas 内核对于 XLA 编译器来说有些不透明,可能会导致它错过一些它通常会执行的优化。例如,XLA 可以异步调度集体操作,以便在不编写自定义内核的情况下交错通信和计算。当涉及到 Pallas 内核时,不能保证会发生这种情况,因此对程序进行性能分析以查看这是否是一个问题非常重要。另一个例子是,我们在本教程中用于生成嵌套管道的 emit_pipeline 函数对 XLA 编译器不可见,因此无法与相邻的操作融合。

下一步#

读者可以进行的优秀后续练习包括实现分布式矩阵乘法、实现 lax.all_to_all,以及放松同步以允许额外的先行执行。