Pallas TPU 上的分布式计算#

在本教程中,我们将介绍在 TPU 上使用 Pallas 进行分布式计算的基础知识。我们将学习 TPU 拓扑结构,使用远程 DMA 原语进行通信,以及使用 shard_map 从 JAX 调用分布式内核。我们还将介绍一些更高级的内核编写技术,例如双缓冲、双向带宽优化和嵌套流水线。作为教育示例,我们将学习如何实现 JAX 中的各种集体原语,例如 lax.ppermutelax.all_gatherlax.psumlax.psum_scatter

一些建议的预先阅读材料

import jax
from jax import lax
from jax import numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental import pallas as pl
from jax.experimental import shard_map
from jax.experimental.pallas import tpu as pltpu

P = jax.sharding.PartitionSpec

num_devices = jax.local_device_count()
assert num_devices > 1, "Please run this notebook with more than one device."
assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices."
print(f"Running with {num_devices} {jax.devices()[0].device_kind} devices.")
Running with 4 TPU v5 lite devices.

TPU 拓扑结构#

TPU 通常以多个设备的 Pod 形式部署,这些设备通过高带宽芯片间互连 (ICI) 相互连接,用于 Pod 内部的通信,其速度远快于典型的网络连接。例如,TPU v5p 的规格表中指出每个芯片的 ICI 带宽为 4.8Tb/s(作为参考,TPU v5p 还具有 21Tb/s 的本地 HBM 带宽)。ICI 使我们能够实现快速且高效的分布式内核,这些内核需要 Pod 内的高带宽通信,并使用数据中心网络来实现带宽密集型操作较少的并行化,例如批次维度上的数据并行化。

TPU Pod 通常以 ND 环面拓扑结构排列。下图给出了不同大小配置的几个示例。

tpu_topologies

将环面展平为图形,可以将其可视化如下。每条边(橙色或黑色)是两个设备之间的双向连接。在讨论设备拓扑结构时,您通常会听到有关环形结构的讨论——环面的一个关键特征是,当沿着 Pod 的轴线(例如节点[(0,1), (1, 1), (2, 1), (3, 1)][(0, 1), (1, 1)])进行切片时,我们会得到一个设备环。这是一个我们可以用来简化 Pod 内通信模式的功能。

tpu_torus

远程直接内存访问 (RDMA) 模型#

TPU 通过一种称为远程直接内存访问 (RDMA) 的仅推送模型进行通信。TPU 允许发出复制指令,将数据从本地缓冲区推送到同一 Pod 内另一个设备上的任何缓冲区,该操作与主程序线程异步执行。但是,TPU 只能读取存储在本地的数据。这与更传统的多分支编程形成对比,在多分支编程中,可以读取和写入共享内存中的值。

异步远程复制操作#

pltpu.make_async_remote_copy 函数用于创建远程 DMA 描述符对象,该对象参数化“发送”操作和“接收”操作。以下是其签名

 def make_async_remote_copy(
     src_ref: Ref,
     dst_ref: Ref,
     send_sem: Ref[SemaphoreType],
     recv_sem: Ref[SemaphoreType],
     device_id: int | tuple[int, ...],
     device_id_type: DeviceIdType
 ) -> AsyncCopyDescriptor:
  • src_ref 是包含您希望发送到另一个设备上dst_ref的数据的本地Ref(位于任何内存空间中)。

  • dst_ref 是远程Ref(位于任何内存空间中),数据将复制到目标设备上的该位置。

  • send_sem 是一个 DMA 信号量,用于阻塞,直到所有数据都已从src_ref发送。

  • recv_sem 是一个 DMA 信号量,用于阻塞,直到在dst_ref处接收到预期字节数。DMA 的发送方将写入接收方的recv_sem

  • device_id 是要发送到的目标设备的设备 ID。

  • device_id_type 指定device_id的格式,可以是 LOGICAL 格式(整数设备 ID),也可以是 MESH 格式(逻辑设备网格中的 ND 元组索引)。默认模式为 MESH。

make_async_remote_copy 返回一个描述符对象,您可以在该对象上使用.start()方法启动 DMA,并使用.wait_send()send_sem上阻塞,并使用.wait_recv()recv_sem上阻塞(或使用.wait()同时阻塞两者)。如果设备仅预期发送数据,则只需调用.start().wait_send()就足够了,同样,如果设备仅接收数据,则只需调用.wait_recv()就足够了。如果使用所有设备都执行 DMA 的 SPMD 模式,则每个设备通常都会同时调用.start().wait()

dma_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id)
dma_descriptor.start() # Initiate the DMA (non-blocking).
# ... do other work
dma_descriptor.wait_send() # Block until all data has been sent.
dma_descriptor.wait_recv() # Block until all data has been received.

例如,让我们可视化一个 DMA,其中我们考虑 4 个设备(索引为 0、1、2、3)。我们考虑一种方案,其中设备 0 复制到设备 1,设备 2 和 3 互相复制。在实践中,我们可以通过使用@pl.when根据设备 ID 分支来创建这种非对称通信模式。

(1) 每个设备创建 DMA 描述符。设备 0、2 和 3 调用.start()以从src_ref启动 DMA。设备 1 跳过.start()并且不执行任何操作,例如通过使用pl.when

rdma_start

(2) 由于.start()是非阻塞的,因此每个设备在 DMA 正在进行时都可以自由地执行其他计算。设备 0、2 和 3 调用.wait_send()以等待send_sem,该信号量会阻塞,直到所有数据都已发送。

rdma_send

(3) 最后,设备 1、2 和 3 将调用.wait_recv()以等待recv_sem,直到所有数据都已到达dst_ref

rdma_recv

上述通信模式可以写成如下

def example_kernel(input_ref, output_ref, send_sem, recv_sem):
    device_id = lax.axis_index('x')
    copy_0_to_1 = pltpu.make_async_remote_copy(
        src_ref=input_ref,
        dst_ref=output_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=1,
    )
    copy_2_to_3 = pltpu.make_async_remote_copy(
        src_ref=input_ref,
        dst_ref=output_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=3,
    )
    copy_3_to_2 = pltpu.make_async_remote_copy(
        src_ref=input_ref,
        dst_ref=output_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=2,
    )
    @pl.when(device_id == 0)
    def _():
      copy_0_to_1.start()
      copy_0_to_1.wait_send()
    @pl.when(device_id == 1)
    def _():
      copy_0_to_1.wait_recv()
    @pl.when(device_id == 2)
    def _():
      copy_2_to_3.start()
      copy_2_to_3.wait_send()
      copy_3_to_2.wait_recv()
    @pl.when(device_id == 3)
    def _():
      copy_3_to_2.start()
      copy_3_to_2.wait_send()
      copy_2_to_3.wait_recv()

DMA 信号量#

send_semrecv_sem是专门为 DMA 使用而保留的特殊类型信号量的实例。在为pallas_call指定输入规范时,必须使用tpu.SemaphoreType.DMA类型分配它们。

在内部,DMA 信号量可以被认为是整数值进度跟踪器。在 DMA 启动时,本地设备将开始异步增加send_sem和接收方的recv_sem的值。等待信号量将阻塞,直到信号量的值达到已发送/接收的数据的总字节数;当达到该值时,等待的线程将被释放,并且信号量的值将减少相同的数量。这意味着所有数据都已发送(对于send_sem)或所有数据都已接收(对于dst_sem)。可以使用pl.semaphore_read读取信号量的值,但请注意,该值的底层语义可能会在硬件代之间发生变化(例如,该值可能并不完全代表已发送的字节数,尽管在推断信号量的行为时,这是一个有用的心理模型)。

路由#

发送方允许将数据发送到同一 Pod 内的任何接收方,即使它们不共享直接连接(此规则的例外情况是 TPU v5e,其中设备只能路由到自身 2 的幂偏移量)。TPU 具有内部路由机制,可以将数据传递到路径上到目标设备的下一个设备。但是,不建议以这种方式进行通信,因为作为内核编写者,您无法控制网络争用。本教程中我们将介绍的示例通过仅将数据传输到相邻设备来最大程度地减少低效通信。

故障模式#

如果远程 DMA 使用不正确,您可能会遇到一些难以调试的故障模式。错误 DMA 使用的一般症状是崩溃、挂起或静默数据损坏

  • 如果信号量在程序退出时具有无效的非零值,则 Pallas 将崩溃并退出程序。

  • 如果等待信号量,但接收到的字节数不足(即没有发送方,或者已发送的数据小于接收设备上dst_ref的大小),则程序可能会无限期挂起,等待永远不会发送的字节。在这种情况下,需要重新启动程序。

  • 如果遇到竞争条件,如果发生两个同时写入或同时读取和写入,则可能会出现静默数据损坏。

以上的一些常见原因包括

  • 如果设备调用.wait_recv(),但没有其他设备发送数据,则内核可能会挂起。

  • 如果向设备发送的字节数超过其预期接收的字节数,则它也可能会由于非零信号量状态而崩溃。如果发送的字节数较少,则它可能会无限期挂起。

  • 如果启动 DMA 但未等待信号量,则程序可能会由于非零信号量状态而崩溃。

  • 如果两个设备复制到同一目标,则您可能会遇到由于竞争条件导致的不确定结果,或者由于非零信号量状态导致的崩溃。

示例:右置换 (lax.ppermute)#

让我们深入了解一个非常基本的示例。我们将实现一个执行右置换的内核,其中每个设备将其数据切片发送到其右侧邻居。

假设我们有一个包含 512 个元素的数组,我们将其分成大小为 128 的切片,跨 4 个设备。每个设备将其切片传递给下一个设备,输出将包含相同的数据,但切片旋转 1 位。这与lax.ppermute操作相同,其中置换设置为(n, (n+1) % 4)

为了以分布式模式调用内核,我们将pallas_call包装在shard_map转换中。从那里,我们可以像编写普通的单设备 Pallas 内核一样编写内核,除了我们现在可以访问远程 DMA 指令。JAX 集体原语(如lax.axis_index)可用于获取device_id,通过引用传递给shard_map的相同命名轴名称,可以用来计算要复制到的目标设备。

partition = P(None, 'x')
devices = mesh_utils.create_device_mesh((1, num_devices))
mesh = jax.sharding.Mesh(devices, partition)
sharding = jax.sharding.NamedSharding(mesh, partition)

# Create an input array that shards the last dimension across
# all devices.
input_arr = jax.random.uniform(jax.random.key(0), (8, 128 * num_devices))
input_arr = jax.device_put(input_arr, sharding)


def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem):
  my_id = lax.axis_index('x')
  right_neighbor = lax.rem(my_id + 1, num_devices)
  remote_copy_op = pltpu.make_async_remote_copy(
      src_ref=input_ref,
      dst_ref=output_ref,
      send_sem=send_sem,
      recv_sem=recv_sem,
      device_id=(0, right_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy_op.start()
  remote_copy_op.wait()


out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)
grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    # TPUMemorySpace.ANY will (usually) place the tensor in HBM.
    in_specs=[
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ],
    out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    scratch_shapes=(
        # We allocate DMA semaphores in scratch memory.
        [pltpu.SemaphoreType.DMA] * 2
    ),
)
right_permute = pl.pallas_call(
    right_permute_kernel,
    out_shape=out_shape,
    grid_spec=grid_spec,
)
# Wrap the kernel within a shard_map to call.
pallas_result = jax.jit(
    shard_map.shard_map(
        right_permute,
        mesh=mesh,
        in_specs=partition,
        out_specs=partition,
        check_rep=False,
    )
)(input_arr)

# Compare Pallas result to XLA shard_map result.
perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices))

xla_result = jax.jit(
    shard_map.shard_map(
        lambda x: lax.ppermute(x, 'x', perm),
        mesh=mesh, in_specs=partition, out_specs=partition)
)(input_arr)

print('Input = ', input_arr[0, ::128])
print('Pallas Result = ', pallas_result[0, ::128])
print('lax.ppermute Result = ', xla_result[0, ::128])
print(
    'Difference |Pallas - lax.ppermute| = ',
    jnp.mean(jnp.abs(pallas_result - xla_result)),
)
Input =  [0.9858954  0.11763906 0.9955574  0.775211  ]
Pallas Result =  [0.775211   0.9858954  0.11763906 0.9955574 ]
lax.ppermute Result =  [0.775211   0.9858954  0.11763906 0.9955574 ]
Difference |Pallas - lax.ppermute| =  0.0

示例:全部收集 (lax.all_gather)#

在接下来的示例中,我们将实现全聚合集体操作,它在 JAX 中等效于 lax.all_gather。与上面提到的右置换示例(仅涉及一对源和目标邻居)相比,全聚合操作需要所有设备之间的通信,因此我们必须考虑数据如何在它们之间路由。我们如何实现这一点的具体细节取决于设备拓扑,我们假设它是一个环形拓扑。

环形通信模式#

我们将假设环形拓扑编写内核。环形拓扑非常适合 TPU,因为沿着环面任何维度切片都会产生一个环形。在编写集体操作时,我们通常只需要考虑环面的 1D 切片,因为环面的不同维度用于不同类型的并行性(例如,数据与模型)。

我们将使用的策略是编写一个循环内核,在每次迭代中,设备从其左侧邻居接收分片数组的一个切片,并将先前接收的切片复制到其右侧邻居。经过 num_devices 次迭代后,每个设备都将在其本地 HBM 中拥有整个数组的副本。

all_gather

我们可以重新利用 Pallas 的 grid 参数来实现循环。与其像我们在之前的教程中那样迭代数组的块,我们改为将网格设置为 (num_devices,) 以指示我们希望循环遍历设备的数量,并使用 pl.program_id 在 Pallas 内核中获取循环迭代。以下代码片段演示了如何实现这一点

partition = P('x', None)
devices = mesh_utils.create_device_mesh((num_devices, 1))
mesh = jax.sharding.Mesh(devices, partition)
sharding = jax.sharding.NamedSharding(mesh, partition)

# Create an input array that shards the first dimension across
# all devices.
input_arr = jax.random.uniform(jax.random.key(0), (8 * num_devices, 128))
input_arr = jax.device_put(input_arr, sharding)


def all_gather_kernel(input_ref,
                      output_ref,
                      local_copy_sem,
                      send_sem,
                      recv_sems):
  outer_step = pl.program_id(0)
  my_id = lax.axis_index('x')
  right_neighbor = lax.rem(my_id + 1, num_devices)
  copy_slot = my_id - outer_step
  copy_slot = lax.rem(copy_slot + num_devices, num_devices)

  @pl.when(outer_step == 0)
  def _():
    local_copy_op = pltpu.make_async_copy(
      src_ref=input_ref,
      dst_ref=output_ref.at[my_id],
      sem=local_copy_sem,
    )
    local_copy_op.start()
    local_copy_op.wait()

  # Copy to our right neighbor.
  # Note that we will also be receiving data from our left neighbor,
  # but at `copy_slot-1` rather than `copy_slot`! This makes use of the fact
  # that the indices do not need to be symmetric between remote DMAs.
  remote_copy_op = pltpu.make_async_remote_copy(
      src_ref=output_ref.at[copy_slot],
      dst_ref=output_ref.at[copy_slot],
      send_sem=send_sem,
      recv_sem=recv_sems.at[outer_step],
      device_id=(right_neighbor, 0),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy_op.start()
  remote_copy_op.wait()

out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32)
grid_spec = pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                # TPUMemorySpace.ANY will (usually) place the tensor in HBM.
                pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
            ],
            out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
            scratch_shapes=(
              # DMA semaphores are allocated in scratch memory.
              # We allocated one semaphore for a local HBM-VMEM copy,
              # and one for the remote send semaphore.
              [pltpu.SemaphoreType.DMA] * 2
              # We additionally allocate one receive semaphore per device.
              # This is to avoid situations where we have multiple
              # DMAs in flight, as we do not want to share a receive
              # semaphore between the DMAs.
              + [pltpu.SemaphoreType.DMA((num_devices-1,))]

            ),
            grid=(num_devices-1,)
        )

all_gather = pl.pallas_call(
      all_gather_kernel,
      out_shape=out_shape,
      grid_spec=grid_spec,
  )

# Wrap the kernel within a shard_map to call.
pallas_result = jax.jit(
      shard_map.shard_map(
          all_gather,
          mesh=mesh,
          in_specs=partition,
          out_specs=partition,
          check_rep=False
      )
)(input_arr)

# Compare Pallas result to XLA shard_map result.
xla_result = jax.jit(
    shard_map.shard_map(
        lambda x: lax.all_gather(x, 'x'),
        mesh=mesh, in_specs=partition, out_specs=partition
    )
)(input_arr)

print('Input: ', input_arr.shape, input_arr[::8, 0])
print('Pallas Result: ', pallas_result.shape, pallas_result[:, 0, 0])
print('lax.all_gather Result: ', xla_result.shape, xla_result[:, 0, 0])
print('Difference |Pallas - lax.all_gather| = ',
      jnp.mean(jnp.abs(pallas_result - xla_result)))
Input:  (32, 128) [0.9858954  0.54248166 0.9547038  0.954962  ]
Pallas Result:  (16, 8, 128) [0.9858954  0.54248166 0.9547038  0.954962   0.9858954  0.54248166
 0.9547038  0.954962   0.9858954  0.54248166 0.9547038  0.954962
 0.9858954  0.54248166 0.9547038  0.954962  ]
lax.all_gather Result:  (16, 8, 128) [0.9858954  0.54248166 0.9547038  0.954962   0.9858954  0.54248166
 0.9547038  0.954962   0.9858954  0.54248166 0.9547038  0.954962
 0.9858954  0.54248166 0.9547038  0.954962  ]
Difference |Pallas - lax.all_gather| =  0.0

这里值得一提的一个细节是使用多个接收信号量。因为我们只阻塞接收设备,所以发送方仍然有可能在接收方完成处理第一个 DMA 之前发送多个正在进行的 DMA(请参阅下一节和 reduce-sum 示例,其中更详细地讨论了竞争条件)。在这种情况下,我们可能会遇到同一信号量同时用于多个 DMA 的情况。为了避免这种情况,我们分配 num_devices-1 个信号量,因此没有重复使用的风险。虽然这种竞争条件不太可能发生在如此小的内核上,但在更大的内核上,设备失去同步并可能导致静默故障的可能性更大。

高级技巧#

现在我们已经了解了如何使用远程 DMA 操作编写几个基本内核,我们将介绍用于同步和编写高效内核的更高级技巧。

同步:常规和屏障信号量#

我们在基本教程中实现的示例不需要特殊处理同步,因为所有必要的通信都写入不相交的缓冲区。但是,其他操作可能需要更复杂的通信模式,这些模式需要额外的同步原语来避免竞争条件。Pallas 提供了两个额外的原语来帮助解决这个问题:常规信号量和屏障信号量。

常规信号量#

常规信号量是用于跨多个设备同步的标准工具。信号量从根本上说是计数器——任何设备都可以递增它们,之后设备可以阻塞,直到信号量的值达到特定值(然后递减该值)。

可以在常规信号量上使用的三个主要操作是信号、等待和读取

def semaphore_signal(
    sem: Ref[SemaphoreType],
    inc: int,
    device_id: int | tuple[int, ...],
    device_id_type: DeviceIdType
) -> None:
  ... # Increments the semaphore `sem` on the target device `device_id` by `inc`.
  
def semaphore_wait(
    semaphore: Ref[SemaphoreType],
    value: int,
) -> None:
  ... # Blocks until the locally allocated copy of `sem` reaches `value`, then decrement by `value` and proceed.
    
def semaphore_read(
    sem: Ref[SemaphoreType],
) -> jax.Array:
  ...  # Returns the current value of `sem` as an `int32[]`.

为了使用常规信号量,可以像 DMA 信号量一样分配它们,但通过指定 pltpu.SemaphoreType.REGULAR 而不是 pltpu.SemaphoreType.DMA

信号量必须在 Pallas 程序结束时为零才能成功完成。有两种错误情况可能会发生这种情况

  • 如果信号量被过度发出信号,程序将在信号量非零(>0)时结束。在这种情况下,程序将在完成时崩溃。这对于调试很有用,因为非零信号量通常意味着程序内部某个地方存在错误。

  • 如果信号量被过度等待,程序将在阻塞的 semaphore_wait 调用上挂起,同时它等待信号量被递增。在这种情况下,需要重新启动设备或程序。

屏障信号量#

屏障信号量是全局分配的信号量,用于跨整个程序同步设备并确保所有设备都已进入 Pallas 内核。

如果在更大的 XLA 程序的上下文中执行 Pallas 内核,我们需要确保所有通信的设备都已进入内核。但是,DMA 和常规信号量都是局部作用域的——只有其他已进入内核的设备才能理解它们。屏障信号量充当全局理解的信号量,无论设备当前在 XLA 程序中的哪个位置执行,都可以用于同步。

默认情况下,如果您没有指定屏障信号量,Pallas 会在程序开头自动插入一个屏障信号量。但是,编写您自己的信号量可能会更有效。屏障信号量类似于常规信号量,它们是可以通过 semaphore_signal 递增和可以通过 semaphore_wait 递减的计数器。它们是通过在内核中调用 get_barrier_semaphore() 创建的。通常,我们会在内核开始时使用一次屏障来与我们正在通信的所有设备同步。

from jax.experimental.pallas import tpu as pltpu

def example_kernel(...):
  # Use barrier semaphores at the beginning of a kernel.
  # is_start_of_kernel = ...
  # right_neighbor = ...
  # ...
  @pl.when(is_start_of_kernel)
  def _():
    barrier_sem = pltpu.get_barrier_semaphore()
    # Increment the semaphore of your right neighbor.
    pltpu.semaphore_signal(
          barrier_sem,
          device_id=right_neighbor,
          device_id_type=pltpu.DeviceIdType.LOGICAL,
    )
    # Wait until your left neighbor has incremented your semaphore
    pltpu.semaphore_wait(barrier_sem, 1)
  # ...

使用屏障信号量时,必须将 collective_id 编译器参数传递给 pallas_call 以指定正在使用的屏障信号量。TPU 具有少量固定的可用屏障信号量(通常在 20-30 个左右),因此应谨慎使用它们。为了确保正确性,仅共享相同通信模式的内核才能使用相同的 collective_id。例如,如果两个内核仅与同一网格轴上的邻居同步,则允许它们共享相同的 collective_id。但是,如果两个内核沿着不同的轴同步,则它们必须具有不同的 collective_id。否则可能会导致难以调试的竞争条件。

kernel = pl.pallas_call(
      example_kernel,
      ...,
      compiler_params=pltpu.TPUCompilerParams(collective_id=0),
)

双缓冲#

为了避免从另一个设备也正在写入的本地 Ref 读取并创建竞争条件,一个有用的技术是“双缓冲”策略,其中我们为每个目标值分配两个 Ref。在每次迭代中,一个 Ref 将被指定为“工作”槽,另一个将被指定为“接收”槽。设备可以自由地使用工作槽进行计算,但只会将其数据复制到其邻居的接收槽中。工作槽和接收槽在每次迭代中交替,因此一旦复制完成,旧的接收槽就成为新的工作槽,反之亦然。正确使用此方案,数据永远不会从同一缓冲区读取和写入。

以下代码框架演示了如何使用双缓冲。我们在变量 iteration 中保留一个正在运行的迭代计数器,并且 working_slotreceiving_slot 在每次迭代中在 0 和 1 之间交替。 dst_ref 被分配为双缓冲区,大小为 [2, ...]。在每次迭代中,我们使用 dst_ref.at[working_slot, ...] 从工作槽读取并使用该值执行计算。同时,我们复制到邻居的 dst_ref.at[receiving_slot] 以避免覆盖其 working_slot 值。通过以这种方式构建我们的通信,可以将远程 DMA 的通信延迟与本地计算重叠,同时最大程度地降低竞争条件的风险。

def kernel(...):
  # ...
  iteration = pl.program_id(0)
  working_slot = lax.rem(iteration, 2)
  receiving_slot = 1 - working_slot
  # ...

  local_copy_op = pltpu.make_async_copy(
    src_ref=dst_ref.at[working_slot, ...],
    dst_ref=local_scratch_ref,
    sem=local_copy_sem,
  )
  local_copy_op.start()
  remote_copy_op = pltpu.make_async_remote_copy(
    src_ref=src_ref,
    dst_ref=dst_ref.at[receiving_slot, ...],
    send_sem=send_sem,
    recv_sem=recv_sem,
    device_id=target_device,
    device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy_op.start()
  
  local_copy_op.wait()
  # ... do work on local_scratch while waiting for async_copy_op to finish.
  remote_copy_op.wait()

在同步方面,如果所有设备都在同一迭代上执行,则双缓冲构造有效。如果发送方设法比其接收方提前一个迭代,则其 working_slotreceiving_slot 索引将与接收方相比翻转,这意味着它可能在接收方正在读取的同时写入 working_slot。为了避免这种情况,可能需要使用信号量来同步发送方和接收方,或添加额外的缓冲槽(“三倍”、“四倍”或 N 倍缓冲)以允许额外的超前运行,但代价是需要更多内存。在我们之前的 all_gather 示例中,请注意内核包含一个具有 N 个槽的接收缓冲区,这完全避免了竞争条件。在我们的下一个内核中,我们将通过一个使用双缓冲和显式同步的示例来进行说明。

示例:全归约求和(lax.psum#

我们现在将使用双缓冲和信号量进行同步来实现全归约求和内核。对于熟悉 JAX 中集体操作的人来说,等效操作是 lax.psum。全归约是一种标准的集体操作,其目标是在数组的某个轴上进行归约,但该数组在多个设备之间进行了分片。

reduce_sum_1

在上面的示例中,我们有一个数组 [5, 2, 1, 3] 在 4 个设备之间分片。全归约求和操作将对所有值求和,并在每个设备上复制结果,从而导致结果 [11, 11, 11, 11] 在所有 4 个设备之间分片。

全归约的朴素实现是将所有所需的值收集到每个设备上,然后进行归约。但是,我们可以通过交错通信和计算来提高此实现的性能。可以将交错的单向全归约可视化如下。在每次迭代中,我们从左侧邻居接收输入值,并同时将输入传递给我们的下一个邻居,同时使用我们的本地累加器递增它。经过 N-1 次迭代后,每个设备都将在其内存中拥有完整总和的副本。

reduce_sum_2

综合以上#

以下内核演示了如何将这些原则组合成一个功能性内核。

序言(在outer_step==0时执行)首先与两个邻居启动一个屏障,以确保它们也已进入内核。它还处理所有Ref的初始化,并处理第一个到右侧邻居“工作”槽的远程复制。

主体部分假设一个值已经复制到我们的本地工作槽中,无论是来自前一次迭代还是来自序言。一个复杂因素是我们的目标缓冲区位于HBM中,但我们需要将值加载到VMEM中才能执行算术运算。因此,我们同时将工作槽值复制到我们的VMEM(receive_scratch)并将该值传递到我们右侧邻居的接收槽。一旦值被复制到我们的VMEM中,我们就可以将其累加到我们的结果中(包含在o_ref中)。

如果一个设备比其右侧邻居多运行一个循环,则可能会发生细微的竞争条件。在这种情况下,它可能会同时将数据复制到接收者的working_slot中,而接收者正在从中读取数据。为了避免这种情况,每个设备在复制到右侧邻居的dst_ref之前,都会在一个REGULAR信号量上阻塞,直到它发出信号表明它已完成从其working_slot读取数据。对于像此示例这样的小内核,这种竞争条件很少触发,但如果例如使用pltpu.delay指令人为地挂起设备,则可以明确触发它。

请注意,这不是一个最优或完全通用的内核,因为块大小必须完全适合VMEM,并且我们可以更好地交错通信和累加。我们将在后面的章节中讨论这些优化。

partition = P(None, 'x')
devices = mesh_utils.create_device_mesh((1, num_devices))
mesh = jax.sharding.Mesh(devices, partition)
sharding = jax.sharding.NamedSharding(mesh, partition)

input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices))
input_arr = jax.device_put(input_arr, sharding)


def all_reduce_kernel(
    x_ref,
    o_ref,
    hbm_scratch,
    copy_sem,
    remote_recv_sem,
    remote_send_sem,
    capacity_sem,
    receive_scratch,
):
  outer_step = pl.program_id(0)
  working_slot = lax.rem(outer_step, 2)
  receiving_slot = 1 - working_slot

  my_id = lax.axis_index('x')
  right_neighbor = lax.rem(my_id + 1, num_devices)
  left_neighbor = lax.rem(my_id - 1 + num_devices, num_devices)

  @pl.when(outer_step == 0)
  def _():
    # Barrier with both neighbors at the start, since we will be
    # communicating with both.
    barrier_sem = pltpu.get_barrier_semaphore()
    pltpu.semaphore_signal(
        barrier_sem,
        inc=1,
        device_id=(0, left_neighbor),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    pltpu.semaphore_signal(
        barrier_sem,
        inc=1,
        device_id=(0, right_neighbor),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    pltpu.semaphore_wait(barrier_sem, 2)

    # Initialize o_ref, acc_scratch, and hbm_scratch.
    o_ref[...] = jnp.zeros_like(o_ref)
    receive_scratch[...] = jnp.zeros_like(receive_scratch)
    initial_copy = pltpu.make_async_remote_copy(
        src_ref=x_ref,
        dst_ref=hbm_scratch.at[working_slot],
        send_sem=remote_send_sem,
        recv_sem=remote_recv_sem,
        device_id=(0, right_neighbor),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    initial_copy.start()
    initial_copy.wait()

  # Signal to our left neighbor that we are ready to receive.
  # Without this signal, our left neighbor can be >=1 iteration ahead,
  # meaning it could write into our working slot.
  pltpu.semaphore_signal(
      capacity_sem,
      inc=1,
      device_id=(0, left_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  # Copy the partial result our left neighbor sent to us into VMEM for
  # computation.
  local_copy = pltpu.make_async_copy(
      src_ref=hbm_scratch.at[working_slot],
      dst_ref=receive_scratch,
      sem=copy_sem,
  )
  local_copy.start()

  # Block until our right neighbor is ready to receive.
  pltpu.semaphore_wait(capacity_sem, 1)
  # Pass the value to our right neighbor.
  remote_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[working_slot],
      dst_ref=hbm_scratch.at[receiving_slot],
      send_sem=remote_send_sem,
      recv_sem=remote_recv_sem,
      device_id=(0, right_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy.start()
  # Finish local copy and accumulate while remote_copy is happening.
  local_copy.wait()
  o_ref[...] += receive_scratch[...]
  # Block until remote copy finishes.
  remote_copy.wait()


out_shape = (
    jax.ShapeDtypeStruct((8, 128), jnp.float32),
    # We allocate the double-buffer as a Pallas output so that it is
    # resident in HBM.
    jax.ShapeDtypeStruct((2, 8, 128), jnp.float32),  # hbm_scratch
)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    in_specs=[
        # Our input lives in VMEM
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
    ],
    out_specs=[
        # Our output lives in VMEM
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        # Our double-buffer lives in HBM
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ],
    grid=(num_devices,),
    scratch_shapes=(
        [pltpu.SemaphoreType.DMA] * 3
        + [pltpu.SemaphoreType.REGULAR]  # capacity_sem
        + [pltpu.VMEM((8, 128), jnp.float32)]  # receive_scratch
    ),
)

kernel = pl.pallas_call(
    all_reduce_kernel,
    out_shape=out_shape,
    grid_spec=grid_spec,
    compiler_params=pltpu.TPUCompilerParams(collective_id=0),
)

pallas_result = jax.jit(
    shard_map.shard_map(
        kernel,
        mesh=mesh,
        in_specs=partition,
        out_specs=partition,
        check_rep=False,
    )
)(input_arr)
pallas_result = jax.block_until_ready(pallas_result)[0]


def lax_sum(x):
  return lax.psum(x, 'x')


xla_result = jax.jit(
    shard_map.shard_map(
        lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x')
    )
)(input_arr)

print('Input = ', input_arr[0, ::128])
print('Pallas result = ', pallas_result[0, ::128])
print('lax.psum result = ', xla_result[0, ::128])
difference = jnp.mean(jnp.abs(pallas_result - xla_result))
print('Difference |Pallas - lax.psum| = ', difference)
Input =  [0.9858954  0.11763906 0.9955574  0.775211  ]
Pallas result =  [2.8743029 2.8743029 2.8743029 2.8743029]
lax.psum result =  [2.8743029 2.8743029 2.8743029 2.8743029]
Difference |Pallas - lax.psum| =  1.4959369e-08

超前运行和竞争条件#

作为一个普遍的经验法则,为了最大化性能,我们希望尽可能地允许一个设备在没有同步的情况下超前于其他设备运行,而不会牺牲程序的正确性。虽然我们可以在每次迭代的开始时强制对所有设备进行屏障,但这会将程序的性能限制在每个循环中最慢的设备上。通过放松同步并允许适度的超前运行,我们可以更好地适应迭代和设备之间延迟的变化,因为在一个迭代中速度较慢的设备可以在下一个迭代中赶上来。

在我们之前编写的all-reduce内核中,我们允许设备超前运行,但与邻居相比少于一个迭代(但是,非相邻设备可能相差超过1个迭代)。要了解为什么需要信号量同步,请考虑一个设备(例如设备2)挂起并落后于其他设备的情况。RDMA没有“握手”——只有接收器在等待数据到达时被阻塞。因此,每个设备可以在被阻塞等待下一个RDMA到达之前最多超前运行一个迭代。如果我们有N个设备,这意味着最后一个设备可以比第一个设备最多超前N个迭代。

race_condition

如果没有在另一个方向添加同步(强制发送方阻塞),则设备1可能潜在地比设备2超前运行多达N个迭代(N = num_devices),在此过程中发送多个写入并覆盖值。为了在我们之前编写的all_reduce内核中解决这个问题,我们实现了一个“握手”协议,其中接收器向发送器发出信号,表示它已准备好接收,只有在那之后,发送器才会开始发出下一个RDMA。

双向通信#

在我们之前的内核中,我们沿着从左到右的环进行单向通信。但是,由于ICI连接是双向的,因此我们通过不沿相反方向(从右到左)发送值,实际上浪费了一半的总带宽。在下一个内核中,我们将演示一个示例,该示例在两个方向上进行通信以最大化ICI带宽。

示例:双向Reduce-Scatter(lax.psum_scatter#

Reduce-scatter操作是all-reduce和scatter的组合。或者,all-reduce是reduce-scatter和all-gather的组合。

下图描绘了此操作的语义。我们假设每个设备都从一组部分和开始(用字母+数字表示,例如A0)。目标是沿着一个轴(数字)进行归约,同时沿着另一个轴(字母)进行分片。

reduce_scatter_1

为了实现双向通信策略,我们将每个输入块分成两半,并为每一半指定一个方向。每个块的上半部分将从右到左传递,下半部分将从左到右传递。与我们之前all-reduce和all-gather内核的通信模式的第二个偏差是,我们还将传递累加器或部分和,并将输入保留在每个设备本地。这与之前的示例形成对比,在之前的示例中,我们传递了输入,但将累加器保留在设备本地。传递累加器更适合此问题,因为与all-reduce相比,输入中的大部分数据不是将存储在设备本地的输出的一部分。(例如,在上图中,B0C0D0最终不会存储在持有A的设备上)。

下图说明了这种通信模式,其中彩色框表示累加器(而不是输入!)。最初,累加器只是输入中包含的值。在算法的每次迭代中,我们都将从每个方向的邻居接收部分和。然后,我们计算输入的正确切片以累加到部分缓冲区中,然后将新的部分和传递给下一个邻居。经过N次迭代后,累加器将经过每个设备,这意味着它最终将保存完整的总和。

reduce_scatter_2

在内核的构建方面,我们在Pallas网格中引入了额外的phase维度,表示我们当前正在计算哪个累加器(左或右)。我们让phase=0表示累加器向左移动,phase=1表示累加器向右移动。然后,我们对这两个阶段进行流水线处理,以便在计算一个阶段的结果时,我们正在相反方向传递先前计算的值,以准备下一个阶段。例如,当我们在phase=0(左)时,我们首先开始一个DMA来将我们在前一次迭代中计算的结果传输到我们的右侧邻居(右DMA)。然后,我们累加到左缓冲区并将结果保存到HBM。然后,我们等待右DMA完成,以便它准备好用于phase=1(右)。

partition = P(None, 'x')
devices = mesh_utils.create_device_mesh((1, num_devices))
mesh = jax.sharding.Mesh(devices, partition)
sharding = jax.sharding.NamedSharding(mesh, partition)

# We need a block size of (16, 128) to ensure that a half-slice is at least
# of size (8, 128), which is the size of a VREG. This makes tiling easier
# for the compiler.
block_size = (16, 128)
input_arr = jax.random.uniform(
    jax.random.key(0),
    shape=(block_size[0] * num_devices, block_size[1] * num_devices),
)
input_arr = jax.device_put(input_arr, sharding)

LEFT = 0
RIGHT = 1


def mod(x, n):
  return lax.rem(x + n, n)


def signal(left_or_right, semaphore):
  my_id = lax.axis_index('x')
  if left_or_right == LEFT:
    neighbor = mod(my_id - 1, num_devices)
  else:
    neighbor = mod(my_id + 1, num_devices)
  pltpu.semaphore_signal(
      semaphore,
      inc=1,
      device_id=(0, neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )


def reduce_scatter_kernel(
    x_ref,
    o_ref,
    hbm_scratch,
    local_copy_sem,
    left_recv_sem,
    left_send_sem,
    right_recv_sem,
    right_send_sem,
    left_capacity_sem,
    right_capacity_sem,
    accum_scratch,
):
  outer_step = pl.program_id(0)
  phase = pl.program_id(1)
  is_start = jnp.logical_and(outer_step == 0, phase == 0)
  last_iteration = outer_step == pl.num_programs(0) - 1

  working_slot = lax.rem(outer_step, 2)
  receiving_slot = 1 - working_slot
  my_id = lax.axis_index('x')
  right_neighbor = mod(my_id + 1, num_devices)
  left_neighbor = mod(my_id - 1, num_devices)

  left_copy_device = mod(my_id + outer_step + 1, num_devices)
  right_copy_device = mod(my_id - outer_step - 1, num_devices)
  # Slices can be specified using pl.ds(start, size)
  left_copy_slice = pl.ds(0, block_size[0] // 2)
  right_copy_slice = pl.ds(block_size[0] // 2, block_size[0] // 2)
  current_phase_slice = pl.ds(phase * (block_size[0] // 2), block_size[0] // 2)

  initial_left_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, left_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(0, left_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  initial_right_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(0, right_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  left_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[working_slot, left_copy_slice],
      dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(0, left_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  right_copy = pltpu.make_async_remote_copy(
      # Note: Right copy is flipped with regards to slots since we are copying
      # to the next outer_step iteration.
      src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(0, right_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  # --- Prologue ---
  @pl.when(is_start)
  def _():
    # Barrier with both neighbors at the start, since we will be
    # communicating with both.
    barrier_sem = pltpu.get_barrier_semaphore()
    pltpu.semaphore_signal(
        barrier_sem,
        inc=1,
        device_id=(0, left_neighbor),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    pltpu.semaphore_signal(
        barrier_sem,
        inc=1,
        device_id=(0, right_neighbor),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    pltpu.semaphore_wait(barrier_sem, 2)

    # Initialize o_ref, acc_scratch, and hbm_scratch with initial copies.
    o_ref[...] = jnp.zeros_like(o_ref[...])
    accum_scratch[...] = jnp.zeros_like(accum_scratch[...])

    initial_left_copy.start()
    initial_left_copy.wait()
    initial_right_copy.start()

    # We tell our left neighbor that it is allowed to send to the right.
    # (and vice versa for right neighbor)
    signal(LEFT, right_capacity_sem)
    signal(RIGHT, left_capacity_sem)

  # --- Body ---
  # At the beginning of our kernel body, we start a DMA which copies
  # the result we computed in the previous phase to our neighbor.
  # This allows us to overlap the communication of sending our previous phase
  # with the computation for the current phase.
  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      # We block here until our right neighbor tells use we can send to
      # the right.
      pltpu.semaphore_wait(right_capacity_sem, 1)
      right_copy.start()

    @pl.when(phase == RIGHT)
    def _():
      # We block here until our left neighbor tells use we can send to
      # the left.
      pltpu.semaphore_wait(left_capacity_sem, 1)
      left_copy.start()

  local_copy = pltpu.make_async_copy(
      src_ref=hbm_scratch.at[working_slot, current_phase_slice],
      dst_ref=accum_scratch,
      sem=local_copy_sem,
  )
  local_copy.start()
  local_copy.wait()

  @pl.when(~last_iteration)
  def _():
    @pl.when(phase == LEFT)
    def _():
      accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]

    @pl.when(phase == RIGHT)
    def _():
      accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]

  local_copy = pltpu.make_async_copy(
      src_ref=accum_scratch,
      dst_ref=hbm_scratch.at[working_slot, current_phase_slice],
      sem=local_copy_sem,
  )
  local_copy.start()
  local_copy.wait()

  @pl.when(is_start)
  def _():
    initial_right_copy.wait()

  # At the end of our kernel body, we wait on the DMA of the previous phase
  # to make sure the results are ready for the next phase.
  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      right_copy.wait()
      signal(LEFT, right_capacity_sem)

    @pl.when(phase == RIGHT)
    def _():
      left_copy.wait()
      signal(RIGHT, left_capacity_sem)

  # --- Epilogue ---
  # Store result on last iteration.
  @pl.when(last_iteration)
  def _():
    # Clean up semaphores so that they exit with a value of 0.
    @pl.when(phase == LEFT)
    def _():
      o_ref[left_copy_slice, ...] = accum_scratch[...]
      pltpu.semaphore_wait(right_capacity_sem, 1)

    @pl.when(phase == RIGHT)
    def _():
      o_ref[right_copy_slice, ...] = accum_scratch[...]
      pltpu.semaphore_wait(left_capacity_sem, 1)


out_shape = (
    jax.ShapeDtypeStruct((block_size[0], block_size[1]), jnp.float32),  # output
    # Shape: [working/recv, block[0], block[1]]
    jax.ShapeDtypeStruct(
        (2, block_size[0], block_size[1]), jnp.float32
    ),  # hbm_scratch
)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    in_specs=[
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
    ],
    out_specs=[
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ],
    grid=(num_devices, 2),
    scratch_shapes=(
        [pltpu.SemaphoreType.DMA] * 5
        + [pltpu.SemaphoreType.REGULAR] * 2  # Capacity semaphores
        + [
            pltpu.VMEM((block_size[0] // 2, block_size[1]), jnp.float32)
        ]  # accum_scratch
    ),
)


def pallas_reduce_scatter(input_arr):
  input_arr = input_arr.reshape(num_devices, block_size[0], block_size[1])
  return pl.pallas_call(
      reduce_scatter_kernel,
      out_shape=out_shape,
      grid_spec=grid_spec,
      compiler_params=pltpu.TPUCompilerParams(collective_id=0),
  )(input_arr)[0]


pallas_result = jax.jit(
    shard_map.shard_map(
        pallas_reduce_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
        check_rep=False,
    )
)(input_arr)

pallas_result = jax.block_until_ready(pallas_result)
# Compare our result to XLA.
def lax_reduce_sum_scatter(x):
  x = x.reshape(num_devices, block_size[0], block_size[1])
  return lax.psum_scatter(x, 'x')


xla_result = jax.jit(
    shard_map.shard_map(
        lax_reduce_sum_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
    )
)(input_arr)

print('Input:', input_arr.shape, input_arr[::4, 0])
print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])
print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0])
print(
    'Difference |Pallas - lax.psum_scatter|:',
    jnp.max(jnp.abs(pallas_result - xla_result)),
)
Input: (64, 512) [0.78051674 0.3524047  0.59993696 0.9714314  0.24692321 0.01347649
 0.01857424 0.24841607 0.86097646 0.8261659  0.9753758  0.6902338
 0.4431417  0.963323   0.3158517  0.535548  ]
Pallas Result: (64, 128) [1.3593563 1.6274805 1.0979297 3.082869  1.4194957 1.4163033 1.2401303
 1.1892898 2.6545286 2.221559  2.7995253 2.08431   2.2509837 3.0726733
 2.4662397 1.9542246]
lax.psum_scatter Result: (64, 128) [1.3593563 1.6274805 1.0979297 3.082869  1.4194957 1.4163033 1.2401303
 1.1892898 2.6545286 2.221559  2.7995253 2.08431   2.2509837 3.0726733
 2.4662397 1.9542246]
Difference |Pallas - lax.psum_scatter|: 2.3841858e-07

嵌套远程和本地DMA流水线#

我们之前编写的all-reduce和reduce-scatter内核的一个限制是,我们通过远程DMA复制的块必须足够小,以适合我们用于累加的工作VMEM。对于某些内核,使用更大的块大小以更好地利用TPU可能更有利。例如,矩阵乘法需要\(O(N^3)\)数量级的计算操作,但只需要\(O(N^2)\)数量级的内存传输。因此,我们希望在设备之间传输的每个工作块都足够大,以便操作变得计算绑定,并且我们可以使用流水线隐藏通信成本。作为参考,TPU的VMEM(对于v4/v5代)通常在10-100MB的范围内,而HBM的范围在10-100GB之间。

为了解决这个问题,我们需要能够编写一个“内部内核”来处理“外部内核”内部的本地HBM-VMEM流水线,该“外部内核”处理设备之间更大的HBM-HBM传输的流水线。Pallas提供了一个用于使用emit_pipeline函数构建嵌套流水线的API。emit_pipeline的基本调用签名遵循标准pallas_call的调用签名,通过为输入和输出指定gridBlockSpec

def emit_pipeline(
    kernel: Callable,
    grid: tuple[int],
    in_specs: PyTree[BlockSpec] = None,
    out_specs: PyTree[BlockSpec] = None,
    should_accumulate_out: bool = False,
    dimension_semantics: tuple[GridDimensionSemantics] = None,
) -> Callable:
  ... # Returns a custom pipeline given an inner kernel and BlockSpecs.

实际上,可以将pallas_call本身简单地视为emit_pipeline的包装器。因为我们的外部内核只涉及远程HBM-HBM传输,所以我们没有使用pallas_call为HBM-VMEM传输提供的任何内置流水线。以下代码骨架演示了使用此模式的典型程序结构


def outer_kernel(...):
  # ... do work to pipeline remote HBM-HBM transfers (outer kernel)

  def inner_kernel(...):
    # ... do work (inner kernel)
  pltpu.emit_pipeline(
          inner_kernel,
          grid=inner_grid,
          in_specs=...,
          out_specs=...,
  )(inner_kernel_args)
  # ... do more work (outer kernel)

pl.pallas_call(
  outer_kernel,
  grid=outer_grid,
  in_specs=...
  out_specs=...
  scratch=inner_kernel_allocs
)

示例:具有大型HBM块的Reduce-Scatter#

在下一个示例中,我们将修改我们之前的reduce-scatter示例以利用嵌套内部流水线。请注意,reduce_scatter的通信和计算成本都与输入的大小线性相关,因此我们不一定期望随着块大小的增加,操作变得计算绑定。此示例纯粹是为了演示如何使用流水线发射器。

我们将增加外部内核的块大小,使其不适合放置在VMEM中,并在HBM中分配所有输入和输出(memory_space=TPUMemorySpace.Any)。与我们之前的内核相比,唯一的重大变化是内核的主体,其中执行累加操作。我们不是手动从HBM复制到VMEM累加器、递增,然后复制回HBM,而是使用emit_pipeline来为我们处理内存传输。累加操作在内部内核中完成,内部内核具有更小、更适合VMEM的块大小。

在我们之前的内核中,我们有以下内核主体来将数据从HBM复制到VMEM累加器、递增,然后将结果复制回HBM

local_copy = pltpu.make_async_copy(
    src_ref=hbm_scratch.at[working_slot, current_phase_slice],
    dst_ref=accum_scratch,
    sem=local_copy_sem,
)
local_copy.start()
local_copy.wait()
@pl.when(~last_iteration)
def _():
  @pl.when(phase == LEFT)
  def _():
    accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]
  @pl.when(phase == RIGHT)
  def _():
    accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]
local_copy = pltpu.make_async_copy(
    src_ref=accum_scratch,
    dst_ref=hbm_scratch.at[working_slot, current_phase_slice],
    sem=local_copy_sem,
)
local_copy.start()
local_copy.wait()

我们的新内核将其替换为以下emit_pipeline调用

def inner_kernel(input_ref, accum_ref):
  accum_ref[...] = input_ref[...]
accum_pipeline = pltpu.emit_pipeline(inner_kernel,
                                     in_specs=[inner_block_spec],
                                     out_specs=inner_block_spec,
                                     should_accumulate_out=True,
                                     grid=inner_grid)
@pl.when(~last_iteration)
def _():
  @pl.when(phase == LEFT)
  def _():
    accum_pipeline(x_ref.at[left_copy_device, left_copy_slice],
                   hbm_scratch.at[working_slot, left_copy_slice],
    )
  @pl.when(phase == RIGHT)
  def _():
    accum_pipeline(x_ref.at[right_copy_device, right_copy_slice],
                   hbm_scratch.at[working_slot, right_copy_slice],
    )

完整的内核如下所示

partition = P(None, 'x')
devices = mesh_utils.create_device_mesh((1, num_devices))
mesh = jax.sharding.Mesh(devices, partition)
sharding = jax.sharding.NamedSharding(mesh, partition)

# We pick a large outer kernel block size that we do not want to place
# in VMEM. For pedagogical purposes we use (4096, 4096), although in
# principle this can be much larger.
outer_block_size = (4096, 4096)
# We pick a smaller VMEM block size for the inner kernel.
inner_block_size = (128, 128)
input_arr = jax.random.uniform(
    jax.random.key(0),
    shape=(
        outer_block_size[0] * num_devices,
        outer_block_size[1] * num_devices,
    ),
)
input_arr = jax.device_put(input_arr, sharding)


inner_grid = (
    outer_block_size[0] // inner_block_size[0] // 2,
    outer_block_size[1] // inner_block_size[1],
)
inner_block_spec = pl.BlockSpec(
    index_map=lambda i, j: (i, j),
    block_shape=inner_block_size,
    memory_space=pltpu.TPUMemorySpace.ANY,
)


def reduce_scatter_kernel(
    x_ref,
    o_ref,
    hbm_scratch,
    left_recv_sem,
    left_send_sem,
    copy_sem,
    right_recv_sem,
    right_send_sem,
    left_capacity_sem,
    right_capacity_sem,
):
  outer_step = pl.program_id(0)
  phase = pl.program_id(1)
  is_start = jnp.logical_and(outer_step == 0, phase == 0)
  last_iteration = outer_step == pl.num_programs(0) - 1

  working_slot = lax.rem(outer_step, 2)
  receiving_slot = 1 - working_slot
  my_id = lax.axis_index('x')
  right_neighbor = mod(my_id + 1, num_devices)
  left_neighbor = mod(my_id - 1, num_devices)

  left_copy_device = mod(my_id + outer_step + 1, num_devices)
  right_copy_device = mod(my_id - outer_step - 1, num_devices)
  left_copy_slice = pl.ds(0, outer_block_size[0] // 2)
  right_copy_slice = pl.ds(outer_block_size[0] // 2, outer_block_size[0] // 2)
  current_phase_slice = pl.ds(
      phase * (outer_block_size[0] // 2), outer_block_size[0] // 2
  )

  initial_left_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, left_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(0, left_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  initial_right_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(0, right_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  left_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[working_slot, left_copy_slice],
      dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(0, left_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  right_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(0, right_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  # --- Prologue ---
  @pl.when(is_start)
  def _():
    # Barrier with both neighbors at the start, since we will be
    # communicating with both.
    barrier_sem = pltpu.get_barrier_semaphore()
    pltpu.semaphore_signal(
        barrier_sem,
        inc=1,
        device_id=(0, left_neighbor),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    pltpu.semaphore_signal(
        barrier_sem,
        inc=1,
        device_id=(0, right_neighbor),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    pltpu.semaphore_wait(barrier_sem, 2)

    initial_left_copy.start()
    initial_left_copy.wait()
    initial_right_copy.start()

    # We tell our left neighbor that it is allowed to send to the right.
    # (and vice versa for right neighbor)
    signal(LEFT, right_capacity_sem)
    signal(RIGHT, left_capacity_sem)

  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      # We block here until our right neighbor tells use we can send to
      # the right.
      pltpu.semaphore_wait(right_capacity_sem, 1)
      right_copy.start()

    @pl.when(phase == RIGHT)
    def _():
      # We block here until our left neighbor tells use we can send to
      # the left.
      pltpu.semaphore_wait(left_capacity_sem, 1)
      left_copy.start()

  # --- Body ---
  def inner_kernel(input_ref, accum_ref):
    # We do not explicitly use += because we set should_accumulate_out=True.
    accum_ref[...] = input_ref[...]

  accum_pipeline = pltpu.emit_pipeline(
      inner_kernel,
      in_specs=[inner_block_spec],
      out_specs=inner_block_spec,
      should_accumulate_out=True,
      grid=inner_grid,
  )

  @pl.when(~last_iteration)
  def _():
    @pl.when(phase == LEFT)
    def _():
      accum_pipeline(
          x_ref.at[left_copy_device, left_copy_slice],
          hbm_scratch.at[working_slot, left_copy_slice],
      )

    @pl.when(phase == RIGHT)
    def _():
      accum_pipeline(
          x_ref.at[right_copy_device, right_copy_slice],
          hbm_scratch.at[working_slot, right_copy_slice],
      )

  # --- Epilogue ---
  @pl.when(is_start)
  def _():
    initial_right_copy.wait()

  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      right_copy.wait()
      signal(LEFT, right_capacity_sem)

    @pl.when(phase == RIGHT)
    def _():
      left_copy.wait()
      signal(RIGHT, left_capacity_sem)

  # Store result on last iteration.
  @pl.when(last_iteration)
  def _():
    output_copy = pltpu.make_async_copy(
        src_ref=hbm_scratch.at[working_slot, current_phase_slice],
        dst_ref=o_ref.at[current_phase_slice],
        sem=copy_sem,
    )
    output_copy.start()
    output_copy.wait()

    # Clean up semaphores so that they exit with a value of 0.
    @pl.when(phase == LEFT)
    def _():
      pltpu.semaphore_wait(right_capacity_sem, 1)

    @pl.when(phase == RIGHT)
    def _():
      pltpu.semaphore_wait(left_capacity_sem, 1)


out_shape = (
    jax.ShapeDtypeStruct(
        (outer_block_size[0], outer_block_size[1]), jnp.float32
    ),
    # Shape: [working/recv, block[0], block[1]]
    jax.ShapeDtypeStruct(
        (2, outer_block_size[0], outer_block_size[1]), jnp.float32
    ),  # hbm_scratch
)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    in_specs=[
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ],
    out_specs=[
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ],
    grid=(num_devices, 2),
    scratch_shapes=(
        [pltpu.SemaphoreType.DMA] * 5
        + [pltpu.SemaphoreType.REGULAR] * 2  # Capacity semaphores
    ),
)


def pallas_reduce_scatter(input_arr):
  input_arr = input_arr.reshape(
      num_devices, outer_block_size[0], outer_block_size[1]
  )
  return pl.pallas_call(
      reduce_scatter_kernel,
      out_shape=out_shape,
      grid_spec=grid_spec,
      compiler_params=pltpu.TPUCompilerParams(collective_id=0),
  )(input_arr)[0]


pallas_result = jax.jit(
    shard_map.shard_map(
        pallas_reduce_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
        check_rep=False,
    )
)(input_arr)

pallas_result = jax.block_until_ready(pallas_result)
# Now we compare our result to XLA.
def lax_reduce_sum_scatter(x):
  x = x.reshape(num_devices, outer_block_size[0], outer_block_size[1])
  return lax.psum_scatter(x, 'x')


xla_result = jax.jit(
    shard_map.shard_map(
        lax_reduce_sum_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
    )
)(input_arr)

print('Input:', input_arr.shape, input_arr[::4, 0])
print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])
print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0])
print(
    'Difference |Pallas - lax.psum_scatter|:',
    jnp.max(jnp.abs(pallas_result - xla_result)),
)
Input: (16384, 16384) [0.74162567 0.0242182  0.27751946 ... 0.05213022 0.36088037 0.04494429]
Pallas Result: (16384, 4096) [2.0648427 1.674587  1.9148926 ... 1.3371865 1.3296283 1.2887063]
lax.psum_scatter Result: (16384, 4096) [2.0648427 1.674587  1.9148926 ... 1.3371865 1.3296283 1.2887063]
Difference |Pallas - lax.psum_scatter|: 2.3841858e-07

最终说明#

Megacore#

某些TPU在Megacore配置中包含多个内核。在此配置中,我们的一般建议是仅从单个内核启动DMA,并且仅执行HBM-HBM传输。为此,将其中一个网格轴设置为内核数(可以通过jax.devices()[0].num_cores获取)并将dimension_semantics设置为"parallel"。然后,您可以使用core_index = pl.program_id(axis)获取该轴上的内核索引,并使用@pl.when(core_index==i)执行特定于该内核的代码。

与XLA的交互#

在本教程中,我们介绍了几个内核示例,这些示例复制了 JAX 中集体操作的功能,例如 lax.all_gatherlax.psumlax.psum_scatter。需要注意一个重要的警告,即 Pallas 内核对于 XLA 编译器来说有些“不透明”,可能会导致它错过一些通常会执行的优化。例如,XLA 可以异步调度集体操作,以便在不编写自定义内核的情况下交错通信和计算。当涉及 Pallas 内核时,这并不能保证会发生,因此务必分析程序以查看这是否是一个问题。另一个例子是,我们本教程中用于生成嵌套流水线的 emit_pipeline 函数对 XLA 编译器不可见,因此无法与相邻操作融合。

后续步骤#

对于读者来说,一些优秀的后续练习包括实现分布式矩阵乘法、实现 lax.all_to_all 以及放松同步以允许更多超前运行。