在 JAX 中排序副作用#

sharadmv@ 2022 年 5 月 9 日

概述#

当我们编写 JAX 代码时,我们通常可以假装我们在编写单线程、急切执行的 Python,即使在幕后,JAX 及其运行时可能会在后台异步执行它。只要我们编写纯(无副作用)代码,这些性能优化通常对我们来说是不可见的,并且不会干扰我们的单线程心理模型。异步执行很棒 - 我们得到了高性能的并行代码,而无需考虑它!

但是,在存在副作用的情况下,这种错觉就开始瓦解,我们心理模型中的裂缝开始显现。具体来说,当我们思考副作用发生的顺序时,这些差异就会显现出来。

在本设计说明中,我们探讨了 JAX 执行模型与副作用排序之间的交互。我们还提供了一种方法来强制执行效果的“单线程”排序。

背景#

当我们编写以下 Python 代码时

def f():
  print("hello")
  return 2
def g():
  print("world")
  return 3
f()
g()

我们期望 "hello""world" 之前打印。这可能看起来很明显,但请考虑以下 JAX 代码

@partial(jax.jit, device=<device 0>)
def f():
  return 2

@partial(jax.jit, device=<device 1>)
def g():
  return 3
f()
g()

在许多情况下,JAX 将并行执行fg,并将计算分派到不同的线程 - g实际上可能在f之前执行。并行执行是一个不错的性能优化,尤其是在将数据复制到设备或从设备复制数据很昂贵的情况下(有关更多详细信息,请参阅异步分派说明)。在实践中,我们通常不必考虑异步分派,因为我们正在编写纯函数,并且只关心函数的输入和输出 - 我们将自然地阻塞在未来值上。

但是,现在假设我们有一个jax.print函数,它在 JIT 编译的 JAX 函数内部工作(host_callback.id_print是此类函数的一个示例)。让我们回到之前的示例,但这次添加了打印操作。

@partial(jax.jit, device=<device 0>)
def f():
  jax.print("hello")
  return 2

@partial(jax.jit, device=<device 1>)
def g():
  jax.print("world")
  return 3
f()
g()

由于异步分派,我们实际上可能会看到"world""hello"之前打印。打印副作用的重新排序打破了单线程执行模型的假象。

另一个可能导致“显示”乱序执行的副作用示例是当我们编译 JAX 程序时。考虑以下 JAX 代码

@jax.jit
def f(x):
  jax.print("hello")
  jax.print("world")
  return x

即使在 Python 中,我们编写了"hello"打印在"world"打印之前,但像 XLA 这样的编译器可以自由地重新排序它们,因为打印之间没有明确的数据依赖关系。

动机#

我们希望支持“有序”的副作用。当我们说有序时,我们的意思是副作用发生的顺序与我们执行单线程 Python 程序时的顺序相同。这是我们的主要愿望。在存在显式并行性(如pmap或用户线程)的情况下,我们不需要保持这种行为,但至少如果用户没有明确地请求并行性,我们希望保留单线程顺序。

在我们深入研究之前,让我们先退一步问问自己,为了性能而重新排序副作用是否可以接受,反之,我们是否需要完全强制对副作用进行排序?在某些情况下,我们不需要排序。也许某些副作用不会对 JAX 程序的性能产生负面影响。但是,对于其他副作用,我们可能希望强制单线程程序顺序,以便用户不会遇到违反直觉的行为。考虑一个记录副作用。

@jax.jit
def f(x, y):
  log_value(x)
  log_value(y)
f(1, 2)

如果log正在修改一个全局列表,我们可能希望我们在添加y之前添加x。对于更严格的副作用,我们可能希望选择对副作用进行排序。

强制有序副作用#

我们用来强制计算顺序的主要工具是数据依赖。简而言之,如果函数g有一个输入是函数f的输出,那么f必须在g之前执行。

但是,我们可能有一些副作用(如打印)完全没有输入,因此我们无法简单地对它们进行排序。因此,我们使用令牌作为一种将人工数据依赖关系注入计算的方法。

什么是令牌?令牌只是一个虚拟值,可以贯穿计算过程。通过将同一个令牌贯穿多个计算过程,我们强制它们必须按照一定的顺序发生。让我们以之前的打印示例为例,看看它在混合令牌的情况下会是什么样子

@jax.jit
def f(token, x):
  token = jax.print(token, "hello")
  token = jax.print(token, "world")
  return token, x

如果我们重写jax.print以接收并返回一个令牌,那么我们就对这两个打印进行了排序,因为第二个打印的输入依赖于第一个打印的输出。实际上,token的值可以是任何东西,但我们将在实践中看到,令牌对用户是不可见的。

运行时令牌与编译器令牌#

在这里,我们将开始讨论实现细节。在实践中,我们将需要两种类型的令牌来对副作用进行排序:一种用于上述每个重新排序来源。我们将需要运行时令牌来对异步分派副作用计算进行排序,并且我们将需要编译器令牌来对计算内的副作用进行排序。

在实践中,我们的计算将被重写成这样

@jax.jit
def f(runtime_token, x):
  compiler_token = new_compiler_token()
  compiler_token = jax.print(compiler_token, "hello")
  compiler_token = jax.print(compiler_token, "world")
  return runtime_token, x

注意,运行时令牌仅在 JIT 边界使用,而编译器令牌仅在编译代码内使用。编译器令牌是在“降低”过程中创建的(我们将 Python 代码转换为更低级的表示形式,如 HLO 或 StableHLO),但运行时令牌需要在 Python 中进行管理,因为它们被贯穿在 JIT 编译的函数中。

此外,请注意,运行时令牌与编译器令牌“断开连接”,这意味着它们之间没有数据依赖关系。这可能很危险,因为如果我们丢失了两个分派函数调用体之间的数据依赖关系。但是,如果我们假设“严格执行” - 即一个分派函数只有在其所有输入准备就绪并且其所有输出同时准备就绪时才会开始执行 - 那么我们就可以安全地创建一个新的编译器令牌并返回一个不依赖于输出的运行时令牌。

管理运行时令牌#

为了代表用户管理运行时令牌,我们需要挂接到 JAX 的分派机制。每当我们调用一个 JIT 编译的函数时,我们最终会到达一个类似于这样的函数

def _execute(compiled_computation, *args):
  outputs = compiled_computation.execute(*args)
  return outputs

此时,我们需要将运行时令牌“注入”到计算中,并将它们从计算的输出中“提取”出来

def _execute(compiled_computation, *args):
  runtime_token = get_runtime_token() # Grab global token
  runtime_token, *outputs = compiled_computation.execute(runtime_token, *args)
  update_runtime_token(runtime_token) # Update global token
  return outputs

runtime_token到底是什么?嗯,我们需要能够将它传递给compiled_computation,这意味着它需要是一种数组(目前是这样,因为编译后的 JAX 代码内部和外部没有共享的令牌表示)。实际上,我们可以使用一个(0,)形状的数组来最大限度地减少开销。

我们还需要考虑多设备使用情况,例如第一个示例,我们首先在设备 0 上调用一个 JIT 编译的函数,然后在设备 1 上调用一个函数。在这种情况下,我们还需要将从第一个计算返回的运行时令牌(它位于设备 0 上)复制到设备 1,以便我们可以将它传递给第二个计算。如果两个后续计算共享同一个设备,则不需要进行此复制。

添加编译器令牌#

当我们将 Python 代码降低为 HLO 或 StableHLO 时,我们需要在计算开始时创建一个令牌,并确保当我们需要对副作用计算进行排序时它们可用。副作用计算将接收令牌作为输入,并将它作为输出返回。

这种令牌贯穿的实现涉及升级 JAX 降低机制以自动执行这种簿记工作。主要挑战包括处理高阶原语(如调用原语和控制流原语)。在本设计说明中,我们不会详细介绍如何处理这些原语。

阻塞输出令牌#

为副作用计算添加运行时令牌和编译器令牌支持对于排序很重要,但令牌还有另一个微妙的用例,即阻塞副作用计算。即使我们不想对副作用计算进行排序,我们可能仍然希望等待它完成。目前,我们有jax.block_until_ready,它会等待一个未来值的结果准备好。但是,对于副作用计算,我们可能有一些函数没有返回值,但仍然在执行副作用。这里举一个简单的例子

@jax.jit
def f():
  jax.print("hello world")
  return
f() # Executed asynchronously

这个编译后的计算没有显式输入,也没有显式输出。如果它是一个有序的打印副作用,我们可以在返回的运行时令牌上阻塞。但是,当这是一个无序计算时,我们不执行任何令牌贯穿操作。当我们没有输出值可以调用block_until_ready时,我们如何等待f()完成执行?嗯,我们可以应用相同的令牌策略,只是我们只返回运行时令牌,而不是将它们作为输入。这将给我们一个值来阻塞,这个值只有在f()完成执行后才会准备好。我们将这些令牌称为输出令牌。我们最终得到一个类似于这样的函数

@jax.jit
def f():
  jax.print("hello world")
  return new_runtime_token()
f() # Executed asynchronously

在幕后,我们将以与我们管理运行时令牌相同的方式管理输出令牌,但提供一种方法让用户在当前的输出令牌集上阻塞。与运行时令牌不同,输出令牌需要是特定于设备的。考虑一个单设备用例

@jax.jit
def f():
  jax.print("hello")

@jax.jit
def g():
  jax.print("world")

f()
g()

由于f()g()在同一个设备上执行,因此在g()的输出令牌上阻塞实际上是在阻塞f(),因为(截至目前!)JAX 运行时不会交错在同一个设备上执行的计算。当然,如果这种情况发生变化,我们将不得不修改整个设计。

但是,考虑两个设备用例

@partial(jax.jit, device=<device 0>)
def f():
  jax.print("hello")

@partial(jax.jit, device=<device 1>)
def g():
  jax.print("world")

f()
g()

在这里,我们不想明确地对f()g()进行排序,但希望等待它们都完成。我们将需要一个输出令牌用于f(),另一个用于g(),并且我们将在这两个令牌上阻塞

@partial(jax.jit, device=<device 0>)
def f():
  jax.print("hello")
  return new_runtime_token()

@partial(jax.jit, device=<device 1>)
def g():
  jax.print("world")
  return new_runtime_token()

t0 = f()
t1 = g()
block_until_ready((t0, t1))

因此,我们将需要一个特定于设备的输出令牌,以便我们可以避免对不同设备上的计算进行排序,同时提供阻塞副作用计算的能力。我们最终得到以下(近似的)对 JAX 分派机制的更改

def _execute(compiled_computation, *args):
  output_token, *outputs = compiled_computation.execute(runtime_token, *args)
  update_output_token(output_token, compiled_computation.device)
  return outputs

我们还需要公开一个在输出令牌上阻塞的函数

def effects_barrier():
  output_token.block_until_ready()

请注意,在输出令牌上阻塞可能并不常见,因为大多数 JAX 计算都会返回一个值来阻塞。但是,输出令牌对于测试和分析很有用,并且值得支持,以便我们有一个一致且连贯的副作用系统。

更多细节#

  • 所有上述令牌管理基础设施将是线程本地的。这意味着每个用户线程将拥有自己的独立运行时令牌流。排序只承诺在用户线程级别进行。

  • 在实践中,我们每个副作用都拥有一个运行时标记。相同副作用的不同实例将被排序。这样做的目的是避免对可能没有相互关系的副作用计算进行排序。从技术上讲,这与我们最初的目标相悖,即强制执行单线程 Python 程序排序,但这是一种折衷方案,可以通过同时拥有“特定于副作用”的标记和“全局”标记来进行调整。