Pallas 异步操作#

背景 + 动机#

我们希望在 Pallas 中公开 API,以便显式地重叠跨多个内核的计算和通信。

XLA 异步分解#

作为动机,请考虑以下 JAX 伪代码

def f(x):
  y = ppermute(x)
  z = x + 1
  return y, z

在此函数中,我们可以在执行 x + 1 的同时执行 ppermute。这是 XLA 通过以下方式自动进行的优化

  1. ppermute 分解为 ppermute_startppermute_done 操作,它们通过 future 连接。

  2. ppermute_startppermute_done 之间调度 x + 1

从而得到以下程序

def f(x):
  fut = ppermute_start(x)
  z = x + 1  # happens at the same time as ppermute
  y = ppermute_done(fut)
  return y, z

内核内部的异步操作#

现在假设我们没有使用 XLA 的 ppermute,而是拥有我们自己的自定义 Pallas ppermute

def ppermute_kernel(x_ref, y_ref, send_sem, recv_sem):
  right_neighbor = ...
  descriptor = pltpu.make_async_remote_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor)
  descriptor.start()
  descriptor.wait_send()
  descriptor.wait_recv()

def ppermute(x):
  return pl.pallas_call(ppermute_kernel, out_shape=x, ...)(x)

目前,我们无法像 XLA 那样将 ppermute 分解为 start/done 对,所以我们显式地将 x + 1 融合到内核中。

def add_one(x_ref, z_ref):
  z_ref[...] = x_ref[...] + 1

def ppermute_add_one_kernel(x_ref, y_ref, z_ref, send_sem, recv_sem):
  right_neighbor = ...
  descriptor = pltpu.make_async_remote_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor)
  descriptor.start()

  # Explicitly schedule inner kernel between start/wait
  pltpu.emit_pipeline(add_one)(x_ref, z_ref)

  descriptor.wait_send()
  descriptor.wait_recv()

def ppermute_and_add_one(x):
  return pl.pallas_call(ppermute_add_one_kernel, out_shape=(x, x), ...)(x)

目标是能够编写单独的内核来启动 ppermute 并等待它完成,以便我们可以在中间使用常规的 x + 1(或我们想要的任何计算)。这使得代码更具可读性、可维护性且不易出错。

我们如何在 (TPU 上) 实现分解的 Pallas 异步操作?#

在 Pallas 中实现分解的异步操作时,要弄清楚的主要事情是它们之间传递的 future 包含什么。具体来说,它必须包含有关在后台发生的操作的一些重要状态。

如果我们查看 Pallas 代码,我们可以看到我们需要一个“描述符”来启动和等待远程复制。我们是否可以将此描述符从 Pallas 内核中取出,然后将其传递到另一个内核中?嗯,有点。底层 TPU 硬件通过一对信号量跟踪异步操作的进度:send_sem 使我们能够等待设备何时完成向其邻居发送数据,而 recv_sem 跟踪从邻居发送到设备的数据传输。如果我们想象编写一个启动内核和一个完成内核,那么我们需要从启动传递到完成的只是信号量以及一些关于在这些信号量上等待多长时间的信息。

我们可以通过扩展 Pallas 以支持从内核返回信号量来实现这一点。

def ppermute_start_kernel(
    in_ref, send_sem, recv_sem, out_ref, *, axis_name,
):
  axis_size = jax.lax.psum(1, axis_name)
  left_neighbor = jax.lax.rem(
      jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size
  )
  right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size)
  barrier_sem = pltpu.get_barrier_semaphore()
  pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor)
  pltpu.semaphore_wait(barrier_sem, 1)
  pltpu.make_async_remote_copy(
      in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor
  ).start()

def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array]:
  send_sem, recv_sem, out = pl.pallas_call(
      functools.partial(ppermute_start_kernel, axis_name=axis_name),
      out_shape=(
          pltpu.SemaphoreType.DMA(()),
          pltpu.SemaphoreType.DMA(()),
          jax.ShapeDtypeStruct(
              x.shape,
              dtype=x.dtype,
          ),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
      ],
      out_specs=(
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.ANY),
      ),
  )(x)
  return send_sem, recv_sem, out

请注意,这里发生了一些微妙的事情。Pallas 正在告诉 XLA 它希望某些输出是信号量(又名同步标志),并且 XLA 会将它们视为“保留”(例如,当它们在 XLA 程序中处于活动状态时,这些同步标志不能被其他内核分配)。它们的行为类似于屏障信号量,这些屏障信号量是由 XLA 管理的保留信号量。

另一件需要注意的事情是,我们从启动内核返回输出缓冲区 out同时它正在主动复制到其中

现在我们编写执行阻塞操作的 done 内核。我们将 out 传递到内核中,以计算阻塞信号量所需的形状。

def ppermute_done_kernel(ref, send_sem, recv_sem, _):
  pltpu.make_async_copy(ref, ref, send_sem).wait()
  pltpu.make_async_copy(ref, ref, recv_sem).wait()

def ppermute_done(send_sem, recv_sem, out) ->Array:
  out = pl.pallas_call(
      ppermute_done_kernel,
      out_shape=(
          jax.ShapeDtypeStruct(
              out.shape,
              dtype=out.dtype,
          ),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
      ],
      out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
      input_output_aliases={0:0}
  )(out, send_sem, recv_sem)
  return out

注意:我们在此处对输出缓冲区进行 i/o 别名,以保证消费者位于 ppermute_done 的下游。

我们现在可以实现分解的集体置换。

def f(x):
  fut = ppermute_start(x)
  z = x + 1  # happens at the same time as ppermute
  y = ppermute_done(fut)
  return y, z

或者我们可以吗?

为什么不起作用#

这里存在三个剩余的问题,每个问题在某种程度上都存在于 Pallas 之外。以下是它们的高级概述。

  1. 调度 - 仅仅因为我们编写了 ppermute_start,然后是 x + 1,然后是 ppermute_done,并不能保证它们会按照该顺序发生。XLA 负责调度,因此当我们编写 JAX 程序时,我们正在设置 XLA 将遵守的数据依赖关系,但 XLA 不会遵守 JAX 中编写的特定操作顺序。

  2. 生命周期 - XLA 假设一旦值在依赖关系图中超出范围,就可以释放其内存以供其他值使用。如果我们有一个异步复制 x -> y 的操作,我们需要确保 x 在复制完成之前处于活动状态,否则我们将从垃圾内存中复制。

  3. 防御性复制 - XLA 保留创建值副本的权利。我们需要确保我们不会引入不必要的副本,以 a) 避免不必要的运行时开销,并 b) 确保正确性。

我们将逐一讨论这些问题并提出修复建议。

调度#

我们如何在 JAX 中显式强制操作以特定顺序发生?请注意,这并不是 Pallas 特有的问题,如果我们使用替代方法实现了异步操作,我们仍然会遇到这个问题。

一种方法是在 XLA 程序中引入优化屏障。优化屏障会阻止 XLA 移动其周围的操作。

这是我们原来的代码

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z

XLA 可以选择在以下三个位置中的任何一个位置执行 x + 1

def f(x):
  z = x + 1
  fut = ppermute_start(x)
  y = ppermute_done(fut)
  return y, z

# OR

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z

# OR

def f(x):
  fut = ppermute_start(x)
  y = ppermute_done(fut)
  z = x + 1
  return y, z

为了强制 x + 1ppermute 操作之间发生,我们可以使用 optimization_barrier,它在语义上是恒等函数(即 lambda x: x),但在值之间引入了显式的数据依赖关系。具体来说,如果我们使 x + 1 中使用的 x 依赖于 ppermute_start 返回的 fut,则它必须在 ppermute_start 之后发生。

我们还引入了一个依赖关系,强制输出值 y 依赖于 z

def f(x):
  fut = ppermute_start(x)
  x, fut = optimization_barrier((x, fut))  # x now depends on fut
  z = x + 1
  z, fut = optimization_barrier((z, fut)) # fut now depends on z
  y = ppermute_done(fut)
  return y, z

optimization_barrier 是一个足够好的工具,可以让我们显式地写出调度。

生命周期#

让我们再次查看我们原来的代码,并假设这些操作正在按正确的顺序发生。

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z

让我们看看 XLA 认为在程序的哪个点可以释放 x 的缓冲区。它将是 x 不再使用的点之后,特别是在 z = x + 1 之后。

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  # XLA can free x here!
  y = ppermute_done(fut)
  return y, z

如果 XLA 在 z = x + 1 完成后释放 x,我们将遇到一个非常糟糕的问题。在 z = x + 1 之后,ppermute 可能仍在积极地将 x 复制到邻居,这意味着如果 x 被释放,ppermute 将从垃圾内存中读取!

我们如何将 x 的生命周期延长到 ppermute_done?嗯,我们可以引入数据依赖关系!我们需要稍微修改一下我们的内核才能实现这一点。

首先,我们重写 ppermute_start 以返回 x,通过内核对其进行别名。

def ppermute_start_kernel(
    in_ref, send_sem, recv_sem, out_ref, _, *, axis_name,
):
  axis_size = jax.lax.psum(1, axis_name)
  left_neighbor = jax.lax.rem(
      jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size
  )
  right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size)
  barrier_sem = pltpu.get_barrier_semaphore()
  pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor)
  pltpu.semaphore_wait(barrier_sem, 1)
  pltpu.make_async_remote_copy(
      in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor
  ).start()

def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array, Array]:
  send_sem, recv_sem, x, out = pl.pallas_call(
      functools.partial(ppermute_start_kernel, axis_name=axis_name),
      out_shape=(
          pltpu.SemaphoreType.DMA(()),
          pltpu.SemaphoreType.DMA(()),
          jax.ShapeDtypeStruct(
              x.shape,
              dtype=x.dtype,
          ),
	   jax.ShapeDtypeStruct(
              x.shape,
              dtype=x.dtype,
          ),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
      ],
      out_specs=(
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.ANY),
      ),
      input_output_aliases={0:2}
  )(x)
  return send_sem, recv_sem, x, out

然后,我们让 ppermute_done 接收 x 并且不对其进行任何操作。

def ppermute_done_kernel(_, ref, send_sem, recv_sem, _):
  pltpu.make_async_copy(ref, ref, send_sem).wait()
  pltpu.make_async_copy(ref, ref, recv_sem).wait()

def ppermute_done(send_sem, recv_sem, x, out) ->Array:
  out = pl.pallas_call(
      ppermute_done_kernel,
      out_shape=(
          jax.ShapeDtypeStruct(
              out.shape,
              dtype=out.dtype,
          ),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
      ],
      out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
      input_output_aliases={1:0}
  )(x, out, send_sem, recv_sem)
  return out

现在当我们编写

def f(x):
  *sems, x ,out = ppermute_start(x)
  z = x + 1
  y = ppermute_done(*sems, x, out)
  return y, z

XLA 无法再释放 x,因为它现在是 ppermute_done 的输入!这意味着 x 的生命周期与 ppermute 绑定,这段代码现在是正确的。

防御性复制#

XLA 在其缓冲区分配阶段,会分析哪些缓冲区彼此别名,并在某个操作对其输入之一进行别名,但该输入并非最终使用者时插入复制。

背景#

这是一个简单的例子。假设我们有一个操作 add_one_inplace,它接收一个数组并加一,但承诺原地操作。

以下代码是合法的。

def f():
  x = jnp.arange(...)
  y = add_one_inplace(x)  return y

但是,如果 x 也有一个单独的消费者,则程序可能无法正确执行。

def f():
  x = jnp.arange(...)
  y = add_one_inplace(x)
  return y, x * 2 # another x consumer!

这是因为 x * 2 操作原始的 x,但 add_one_inplace 会覆盖 x 中的值。x * 2 需要确保读取 x 的原始值,而不是将其加 1 后的值。XLA 注意到这一点,并插入一个 copy 操作(在语义上是恒等操作,但输入和输出缓冲区将不同)。

def f(x):
  x2 = copy(x)
  y = add_one_inplace(x2)
  return y, x * 2

XLA 中的此阶段通过强制使用 copy 操作来有效地进行异地更新,从而确保在执行原地更新的操作时保持正确性。

带有下游操作的复制#

让我们重新审视一下在进行 ppermute 时加 1 的示例。

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z

如果我们把 future 展开成它的组成部分,我们将看到别名模式

def f(x):
  *sems, x2, y = ppermute_start(x)
  z = x + 1
  y = ppermute_done((*sems, x2, y))
  return y, z

我们知道 xppermute_start 中保持不变(即,xx2 相同),但 XLA 不知道。实际上,它看起来像我们给 XLA 的 add_one_inplace 示例,它保守地假设 ppermute_start 修改了 x,而 x2 是新的别名结果。因此,当我们执行 z = x + 1 时,我们会遇到原始缓冲区的消费者。因此,XLA 引入了一个复制操作!

def f(x):
  x2 = copy(x)
  *sems, x2, y = ppermute_start(x2)
  z = x + 1
  y = ppermute_done((*sems, x2, y))
  return y, z

这个复制操作是不必要的,因为我们知道 x2x 相比没有改变。为了删除这个复制操作,我们需要某种机制来通知 XLA 我们只是在转发一个值。但是,在没有这种机制的情况下,我们可以稍微重写一下程序,以显式使用 x2 而不是 x

def f(x):
  *sems, x2, y = ppermute_start(x)
  z = x2 + 1
  y = ppermute_done((*sems, x2, y))
  return y, z

现在,XLA 没有看到 x 的单独使用者,因此不会引入复制操作。但是,这带来了一个主要的缺点,即它迫使我们解包来自 ppermute_start 的 future。它将生命周期问题与复制问题联系起来。

循环别名#

让我们考虑一个稍微高级的例子。让我们实现一个使用 while_loopppermute 在环中发送值的函数。

def f(x):
  def body(i, x):
    fut = ppermute_start(x)
    y = ppermute_done(fut)
    return y
  return fori_loop(0, 8, body, x)

fori_loop 的一个实现细节是,输入和输出缓冲区会自动相互别名。请注意,我们正在 ppermute_startppermute_done 操作中设置一些额外的别名。让我们通过对程序中的每个值进行着色来运行我们自己的“缓冲区分配”,以确定我们需要多少个唯一缓冲区。

首先,我们将解包具有别名 xout 缓冲区的 fut 元组。

def f(x):
  def body(i, x):
    *sems, x, y = ppermute_start(x)
    y = ppermute_done(*sems, x, y)
    return y
  return fori_loop(0, 8, body, x)

现在,根据分配给它们的唯一缓冲区对每个值进行着色。我们有来自 fori_loop 的输入/输出别名,来自 ppermute_startx 别名和来自 ppermute_doney 别名。

def f(x):
  def body(i, x):
    *sems, x, y = ppermute_start(x)
    y = ppermute_done((*sems, x, y))
    return y
  return fori_loop(0, 8, body, x)

如果您运行别名分析,您会发现所有缓冲区都已被着色为相同!直观地说,这是有问题的,因为如果我们正在执行 ppermute 的循环,我们就不能写入我们正在发送到的同一个缓冲区。我们通常需要一个额外的(即“双重”)缓冲区来接收,然后通常我们会在下一次迭代时切换发送/接收缓冲区。XLA 在实践中所做的是,它会观察到缓冲区重用并防御性地插入一个复制操作。

def f(x):
  def body(i, x):
    x = copy(x)
    *sems, x, y = ppermute_start(x)
    y = ppermute_done((*sems, x, y))
    return y
  return fori_loop(0, 8, body, x)

此复制操作意味着 xy 不再彼此别名,并且程序将是正确的。但是,我们需要这个复制操作吗?我们如何引入双重缓冲区以避免每次迭代都进行昂贵的复制?答案是展开!

我们将手动展开我们的代码。

def f(x):
  def body(i, x):
    *sems, x, x2 = ppermute_start(x)
    x2 = ppermute_done((*sems, x, x2))
    
    *sems, x2, y = ppermute_start(x2)
    y = ppermute_done((*sems, x2, y))
    return y
  return fori_loop(0, 4, body, x)

现在,如果我们运行相同的别名分析,我们会发现缓冲区不再彼此别名,并且我们不需要插入防御性复制来确保正确性。

因此,删除这些复制操作的简单解决方案是使用 fori_loop,其中 unroll >= 2

def f(x):
  def body(i, x):
    fut = ppermute_start(x)
    y = ppermute_done(fut)
    return y
  return fori_loop(0, 8, body, x, unroll=2)

这足以在没有额外复制的情况下实现此循环!

跨循环边界传递 future#

现在让我们看一个更高级的例子。我们将实现与之前相同的程序,但会错开循环,我们在循环之前的序言中开始 ppermute,并在循环开始时等待 ppermute

def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    x = ppermute_done(fut)
    fut = ppermute_start(x)
    return fut
  fut = fori_loop(0, 7, body, fut)
  return ppermute_done(fut)

在此示例中,我们传递的是 future 值,而不是将值 x 从一个循环传递到另一个循环。

让我们再次解包 future,看看发生了什么。

def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    *sems, x, out = fut
    x = ppermute_done((*sems, x, out))
    (*sems, x, out) = ppermute_start(x)
    return (*sems, x, out)
  (*sems, x, out) = fori_loop(0, 7, body, x)
  return ppermute_done((*sems, x, out))

因此,我们正在显式地将信号量、输入缓冲区和目标输出缓冲区作为循环携带。如果我们现在运行别名分析会发生什么?好吧,我们会遇到与上一节中相同的别名问题,其中 xout 将彼此别名。XLA 将引入一个复制操作。

def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    *sems, x, out = fut
    out = copy(out)
    x = ppermute_done((*sems, x, out))
    (*sems, x, out) = ppermute_start(x)
    return (*sems, x, out)
  (*sems, x, out) = fori_loop(0, 7, body, x)
  return ppermute_done((*sems, x, out))

在这种情况下,我们在 out 上插入了一个复制操作。但是,这是一个非常糟糕的情况,因为 out 正在被积极地复制到!即使我们在 x 上插入一个复制操作,我们也会遇到问题,因为 x 的生命周期不会延伸到 ppermute_done。这非常非常糟糕!我们不仅会得到复制操作,还会得到不正确的结果!

正如我们之前观察到的,解决方案是通过展开来避免所有缓冲区别名,从而避免复制操作。因此,如果我们这样做

def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    x = ppermute_done(fut)
    fut = ppermute_start(x)
    return fut
  fut = fori_loop(0, 7, body, x, unroll=2)
  return ppermute_done(fut)

我们的程序现在应该是正确的。

总结#

因此,我们总结了一些经验法则

  1. 如果我们的操作依赖于 ppermute 的输入值,则解包 future 以使用别名值而不是原始值。

  2. 在循环体中执行 ppermute 时,使用 unroll >= 2

让我们将所有内容组合到一个函数中,该函数在循环中执行 ppermute 并累积结果。

def f(x):
  out = jnp.zeros_like(x)
  fut = (*sems, x, out) = ppermute_start(x)
  out = out + x
  def body(i, carry):
    out, fut = carry
    x = ppermute_done(fut)
    fut = (*sems, x, out) = ppermute_start(x)
    out = out + x
    return out, fut
  out, fut = fori_loop(0, 7, body, (out, fut), unroll=2)
  return out, ppermute_done(fut)

请注意,在此示例中,我们不需要 optimization_barrier,因为循环边界充当调度屏障,从而将 startdone 分开。

就这样,我们完成了!这将是 Pallas 中执行异步操作的官方 API。谢谢大家!任务完成!

或者,是这样吗?

状态的复仇#

虽然我们似乎通过使用一些巧妙的技巧解决了复制和不正确的问题,但我们仍然处于尴尬的境地。这个 API 功能强大,但有很多陷阱和注意事项。可能还有更多的边缘情况需要我们处理,这些情况甚至需要深入了解 XLA 才能预测或理解。我们应该发布这样的 API 吗?或者是否有替代方案?

好吧,答案可能一直就在我们面前。

让我们再运行一次整个练习,除了,让我们编写有状态的版本。这意味着我们的每个自定义异步操作现在都在 Ref 而不是值上进行操作。

def ppermute_start_stateful(x_ref, y_ref) -> tuple[Semaphore, Semaphore]:
  ...

def ppermute_done_stateful(send_sem, recv_sem, x_ref, y_ref) -> None:
  ...

假设我们可以在 Pallas 中实现这些,看看我们的新程序会是什么样子。让我们从基本的集体置换开始

def f(x):
  x_ref = make_ref(x)
  y_ref = make_ref(zeros_like(x))
  fut = ppermute_start_stateful(x_ref, y_ref)
  ppermute_done_stateful(*fut, x_ref, y_ref)
  return y_ref[...]

它比我们最初的基于值的版本稍微冗长一些,但它有一些关键的区别。首先,我们创建了一个“空的” Ref 来接收 ppermute 的结果,这与基于值的版本不同,后者会为我们创建一个值。一个巧妙之处在于,x_ref 的生命周期在这里很明确:它一直存在到 ppermute_done_stateful。我们不需要像以前那样将 x 值“偷偷地”放入操作中。

当我们尝试在 start/done 之间添加一个操作时,另一个区别变得更加明显。

def f(x):
  x_ref = make_ref(x)
  y_ref = make_ref(zeros_like(x))
  fut = ppermute_start_stateful(x_ref, y_ref)
  x_ref[...] += 1
  ppermute_done_stateful(*fut, x_ref, y_ref)
  return y_ref[...]

之前,我们遇到了调度歧义,其中 XLA 可以相对于 ppermute 重新排序加法操作。使用有状态语义,我们实际上添加了一个排序约束! x_ref[...] += 1 修改了 x_ref,因此它不能相对于 ppermute_done_stateful 移动。JAX 可以将这些调度约束作为降级到 HLO 的一部分注入。

当我们尝试循环示例时,最终的关键区别就显现出来了。

def f(x):
  x_ref = make_ref(x)
  y_ref = make_ref(zeros_like(x))
  def body(i, _):
    fut = ppermute_start_stateful(x_ref, y_ref)
    ppermute_done_stateful(*fut, x_ref, y_ref)
    # Now switch to y_ref -> x_ref
    fut = ppermute_start_stateful(y_ref, x_ref)
    ppermute_done_stateful(*fut, y_ref, x_ref)
  fori_loop(0, 8 // 2, body, None)
  return x_ref[...]

由于我们需要一个单独的缓冲区来接收 ppermute,我们被迫以展开循环的方式编写代码!没有办法在 XLA 中编写需要复制的版本,因为这会涉及到从 Ref 发送到自身的 ppermute,这实际上没有意义。

为了在没有手动展开的情况下处理这个问题,我们将创建一个具有前导 2 维度的临时缓冲区,该缓冲区充当迭代之间的发送/接收目标,并在每次迭代时切换。这与我们在编写手动重叠内核时在 Pallas 内核中内部使用的模式相同。

这里的认识是,有状态迫使我们更早地处理与值语义相关的许多问题。我们将它们定义排除!

  1. 调度 - 将 Ref 作为输入的有状态操作强制对我们的程序进行排序。请注意,这将对同一个 Ref 上的操作进行相互排序。我们可能还需要一个 opt_barrier_stateful 来强制执行更多的排序约束。

  2. 生命周期 - Ref 的生命周期可以通过 run_state 来限定作用域,或者可以是有状态操作的输入。

  3. 防御性复制 - 使用 Ref 迫使我们“手动”处理缓冲区分配,并且降级可以确保别名工作正常以避免任何复制。

另一个重要的基本限制是,我们最终会分阶段输出一个 HLO 程序,其中活动缓冲区和信号量表示为数组值类型。XLA 不保证这些中间值的缓冲区生命周期或它们所在的内存空间。因此,即使 Pallas 内核正在主动复制数组值,XLA 也可能会复制数组值。 这在 HLO 中很容易验证,但是使用自定义调用来表示 HLO 中的异步操作是一个尖锐的边缘。

结论#

我们已经讨论了在 Pallas 和 JAX 中处理异步操作时遇到的一些棘手挑战。 Ref 似乎是表示这些操作的一种有前途的方法,它可以规避基于值语义出现的一些问题。然而,一个缺点是它将有状态的 JAX 放在首位,这是我们除了 Pallas 之外还没有做过的。值得思考的是,我们是否应该向用户介绍有状态的操作,或者提供一个更危险的 API。我们也不知道我们想要做的一切是否也可以通过 Ref 来表达。我们还应该集思广益,寻找状态的替代方案,以充实设计空间。例如,如果 XLA 提供一个尊重生命周期的第一类 futures API,并且它可以自动执行诸如使用 futures 在其中进行双缓冲循环之类的操作会怎么样?这可能是一个可行的替代方案,但权衡之处在于将更多的控制权交给编译器,而不是用户明确控制。