JAX 中的副作用排序#
sharadmv@ 2022 年 5 月 9 日
概述#
当我们编写 JAX 代码时,通常可以假装我们正在编写单线程、急切执行的 Python,即使在底层,JAX 及其运行时可能会在后台异步执行它。只要我们编写纯(无副作用)代码,这些性能优化通常对我们是不可见的,并且不会干扰我们的单线程心理模型。异步执行非常棒——我们可以获得高性能的并行代码,而无需考虑它!
然而,在出现副作用的情况下,这种幻觉开始瓦解,并且我们心理模型中的裂缝开始显现。具体来说,当我们考虑副作用发生的顺序时,这些差异就会显现出来。
在本设计说明中,我们探讨了 JAX 的执行模型与副作用排序之间的相互作用。我们还提供了一种强制执行效果的“单线程”排序的方法。
背景#
当我们编写以下 Python 代码时
def f():
print("hello")
return 2
def g():
print("world")
return 3
f()
g()
我们期望先打印 "hello"
,然后再打印 "world"
。这似乎是显而易见的,但请考虑以下 JAX 代码
@partial(jax.jit, device=<device 0>)
def f():
return 2
@partial(jax.jit, device=<device 1>)
def g():
return 3
f()
g()
在许多情况下,JAX 会并行执行 f
和 g
,将计算分派到不同的线程 —— g
实际上可能在 f
之前执行。并行执行是一种很好的性能优化,特别是当设备之间的复制开销很大时(有关更多详细信息,请参阅异步调度说明)。然而,在实践中,我们通常不需要考虑异步调度,因为我们正在编写纯函数,并且只关心函数的输入和输出 —— 我们自然会阻塞未来值。
但是,现在想象一下我们有一个可以在 JIT 化的 JAX 函数内部工作的 jax.print
函数(host_callback.id_print
就是一个例子)。让我们回到前面的例子,但混合使用打印。
@partial(jax.jit, device=<device 0>)
def f():
jax.print("hello")
return 2
@partial(jax.jit, device=<device 1>)
def g():
jax.print("world")
return 3
f()
g()
由于异步调度,我们实际上可能会看到先打印 "world"
,然后再打印 "hello"
。打印副作用的重新排序打破了单线程执行模型的假象。
当编译 JAX 程序时,副作用“揭示”乱序执行的另一个例子。考虑以下 JAX 代码
@jax.jit
def f(x):
jax.print("hello")
jax.print("world")
return x
即使在 Python 中,我们在 "world"
打印之前编写了 "hello"
打印,像 XLA 这样的编译器也可以自由地重新排序它们,因为打印之间没有显式的数据依赖关系。
动机#
我们希望支持“有序”效应。当我们说有序时,我们的意思是这些效应的发生顺序与我们执行单线程 Python 程序时的顺序相同。这是我们的主要目标。在存在像 pmap
或用户线程这样的显式并行性时,我们不需要维持这种行为,但至少如果用户没有明确请求并行性,我们希望保留单线程排序。
在我们深入研究之前,让我们先退一步问问自己,为了性能而重新排序效果是否可以,反之,我们是否需要强制执行效果的排序?在某些情况下,我们不需要排序。也许某些副作用不应该对 JAX 程序的性能产生不利影响。但是,对于其他副作用,我们可能希望强制执行单线程程序顺序,这样用户就不会得到违反直觉的行为。考虑一个日志记录效果。
@jax.jit
def f(x, y):
log_value(x)
log_value(y)
f(1, 2)
如果 log
正在改变全局列表,我们可能会期望在添加 y
之前添加 x
。对于更严格的效果,我们可能需要选择对效果进行排序。
强制执行有序效果#
我们用来强制执行计算顺序的主要工具是数据依赖。简而言之,如果函数 g
的输入是函数 f
的输出,则 f
必须在 g
之前执行。
但是,我们可能有一些像打印这样的副作用,它们根本没有输入,因此我们不能简单地对它们进行排序。因此,我们使用令牌作为将人为数据依赖性注入计算的方法。
什么是令牌?令牌只是一个虚拟值,可以在计算中传递和取出。通过将相同的令牌传递到多个计算中,我们强制它们必须按特定顺序发生。让我们以前面的打印示例为例,看看它在混合中使用令牌时会是什么样子
@jax.jit
def f(token, x):
token = jax.print(token, "hello")
token = jax.print(token, "world")
return token, x
如果我们重写 jax.print
以接收和返回令牌,我们现在已经对两个打印进行了排序,因为第二个打印的输入取决于第一个打印的输出。token
的实际值实际上可以是任何值,但我们在实践中会看到令牌对用户是不可见的。
运行时令牌与编译器令牌#
在这里,我们将实际开始讨论实现细节。在实践中,我们需要两种不同类型的令牌来对效果进行排序:一种用于上述每种重新排序的来源。我们需要运行时令牌来对异步调度的副作用计算进行排序,并且我们需要编译器令牌来对计算内的效果进行排序。
在实践中,我们的计算将被重写成这样
@jax.jit
def f(runtime_token, x):
compiler_token = new_compiler_token()
compiler_token = jax.print(compiler_token, "hello")
compiler_token = jax.print(compiler_token, "world")
return runtime_token, x
请注意,运行时令牌仅在 JIT 边界使用,而编译器令牌仅在编译的代码中使用。编译器令牌在“降级”期间创建(我们将 Python 代码转换为较低级别的表示形式,例如 HLO 或 StableHLO),但运行时令牌需要在 Python 中管理,因为它们正在 JIT 化的函数中传递和取出。
此外,请注意,运行时令牌与编译器令牌“断开连接”,这意味着它们之间没有数据依赖关系。这可能很危险,因为如果我们将失去两个调度的函数调用主体之间的数据依赖关系。但是,如果我们假设“严格执行”——即,只有当所有输入都准备好时,才会开始执行调度的函数,并且所有输出将同时准备好——我们可以安全地创建一个新的编译器令牌并返回一个不依赖于输出的运行时令牌。
管理运行时令牌#
为了代表用户管理运行时令牌,我们需要挂钩到 JAX 的调度机制中。每当我们调用 JIT 化的函数时,我们最终都会归结为一个看起来像这样的函数
def _execute(compiled_computation, *args):
outputs = compiled_computation.execute(*args)
return outputs
此时,我们需要将运行时令牌“注入”到计算中,并从计算的输出中“提取”它们
def _execute(compiled_computation, *args):
runtime_token = get_runtime_token() # Grab global token
runtime_token, *outputs = compiled_computation.execute(runtime_token, *args)
update_runtime_token(runtime_token) # Update global token
return outputs
runtime_token
到底是什么?好吧,我们需要能够将其传递到 compiled_computation
中,这意味着它需要是某种数组(目前,因为编译的 JAX 代码的内部和外部没有共享的令牌表示形式)。在实践中,我们可以使用形状为 (0,)
的数组来最小化开销。
我们还需要考虑多设备使用情况,例如,第一个示例中,我们首先在设备 0 上调用 JIT 化的函数,然后在设备 1 上调用 JIT 化的函数。在这种情况下,我们还需要将从第一个计算返回的运行时令牌(位于设备 0 上)复制到设备 1,以便我们可以将其传递到第二个计算中。如果两个后续计算共享同一设备,则此复制不是必需的。
添加编译器令牌#
当我们将 Python 代码降级为 HLO 或 StableHLO 时,我们需要在计算开始时创建一个令牌,并确保当我们有需要排序的副作用计算时它们可用。副作用计算将令牌作为输入并将其作为输出返回。
此令牌线程的实现涉及升级 JAX 降级机制以自动执行此簿记。主要挑战包括处理高阶原语(例如调用原语和控制流原语)。我们不会在此设计说明中详细介绍如何处理这些原语。
阻塞输出令牌#
为副作用计算添加对运行时和编译器令牌的支持对于排序很重要,但令牌还有另一个微妙的用例,即阻塞副作用计算。即使我们不希望副作用计算是有序的,我们可能仍然希望等待其完成。目前,我们有 jax.block_until_ready
,它会等待直到未来值准备好结果。但是,对于副作用计算,我们可能有不具有返回值但仍在执行副作用的函数。以下面的简单示例为例
@jax.jit
def f():
jax.print("hello world")
return
f() # Executed asynchronously
这个编译的计算不接受显式输入并且没有显式输出。如果它是一个有序的打印效果,我们可以在返回的运行时令牌上阻塞。但是,当这是一个无序计算时,我们不进行任何令牌线程处理。当我们没有输出值来调用 block_until_ready
时,我们如何等待 f()
完成执行?好吧,我们可以应用相同的令牌策略,只不过我们只返回运行时令牌,而不将它们作为输入。这将为我们提供一个可以阻塞的值,该值只有在 f()
完成执行后才会准备就绪。我们将这些令牌称为输出令牌。我们最终得到一个看起来像这样的函数
@jax.jit
def f():
jax.print("hello world")
return new_runtime_token()
f() # Executed asynchronously
在底层,我们将以管理运行时令牌的相同方式管理输出令牌,但为用户提供一种阻塞当前输出令牌集的方法。与运行时令牌不同,输出令牌需要是特定于设备的。考虑一个单设备用例
@jax.jit
def f():
jax.print("hello")
@jax.jit
def g():
jax.print("world")
f()
g()
由于 f()
和 g()
在同一设备上执行,因此阻塞 g()
的输出令牌实际上会阻塞 f()
,因为(到目前为止!),JAX 运行时不会交错在同一设备上执行的计算。如果这种情况发生变化,我们将不得不修改整个设计。
但是,请考虑双设备用例
@partial(jax.jit, device=<device 0>)
def f():
jax.print("hello")
@partial(jax.jit, device=<device 1>)
def g():
jax.print("world")
f()
g()
这里我们不想显式地对 f()
和 g()
进行排序,而是希望等待它们都完成。我们需要为 f()
提供一个输出令牌,为 g()
提供一个输出令牌,并且我们将阻塞这两个令牌。
@partial(jax.jit, device=<device 0>)
def f():
jax.print("hello")
return new_runtime_token()
@partial(jax.jit, device=<device 1>)
def g():
jax.print("world")
return new_runtime_token()
t0 = f()
t1 = g()
block_until_ready((t0, t1))
因此,我们需要一个每个设备的输出令牌,这样我们就可以避免在不同设备上对计算进行排序,同时提供阻塞副作用计算的能力。我们最终对 JAX 调度机制进行了以下(近似)更改
def _execute(compiled_computation, *args):
output_token, *outputs = compiled_computation.execute(runtime_token, *args)
update_output_token(output_token, compiled_computation.device)
return outputs
我们还需要公开一个函数来阻塞输出令牌
def effects_barrier():
output_token.block_until_ready()
请注意,阻塞输出令牌可能并不常见,因为大多数 JAX 计算都会返回一个值来阻塞。但是,输出令牌对于测试和分析很有帮助,并且最好支持它,以便我们拥有一个一致且有凝聚力的效果系统。
更多细节#
所有上述令牌管理基础设施都将是线程本地的。这意味着每个用户线程都将拥有自己的独立运行时令牌流。仅在用户线程级别保证排序。
实际上,我们每个效果都有一个运行时令牌。该效果的不同实例将按顺序执行。这是为了避免对可能彼此没有关系的有效计算进行排序。从技术上讲,这违背了我们最初强制执行单线程 Python 程序排序的目标,但这可以通过同时拥有“效果”特定令牌和“全局”令牌来调节的权衡。