Just-in-time 编译#

在本节中,我们将进一步探讨 JAX 的工作原理以及如何使其具有高性能。我们将讨论 jax.jit() 变换,它将执行 JAX Python 函数的 *Just In Time* (JIT) 编译,以便它能够在 XLA 中高效地执行。

JAX 变换的工作原理#

在上一节中,我们讨论了 JAX 允许我们变换 Python 函数。JAX 通过将每个函数简化为一系列 原语 操作来实现这一点,每个操作代表一个基本的计算单元。

查看函数背后的一系列原语的一种方法是使用 jax.make_jaxpr()

import jax
import jax.numpy as jnp

global_list = []

def log2(x):
  global_list.append(x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2)(3.0))
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }

文档的 理解 Jaxpr 部分提供了有关上述输出含义的更多信息。

重要的是,请注意 jaxpr 没有捕获函数中存在的副作用:其中没有任何内容对应于global_list.append(x)。这是一个特性,而不是错误:JAX 变换旨在理解无副作用(也称为函数式纯)代码。如果纯函数副作用是陌生的术语,则在🔪 JAX - The Sharp Bits 🔪: 纯函数中对此进行了更详细的解释。

不纯函数很危险,因为在 JAX 变换下,它们很可能无法按预期工作;它们可能会静默失败,或者产生令人惊讶的下游错误,例如泄漏的 Tracers。此外,JAX 通常无法检测到副作用的存在。(如果需要调试打印,请使用jax.debug.print()。要以性能为代价表达通用副作用,请参阅jax.experimental.io_callback()。要以性能为代价检查 Tracer 泄漏,请与jax.check_tracer_leaks()一起使用)。

在跟踪时,JAX 会用tracer对象包装每个参数。然后,这些 tracer 会记录在函数调用期间(在常规 Python 中发生)对它们执行的所有 JAX 操作。然后,JAX 使用 tracer 记录来重建整个函数。该重建的输出是 jaxpr。由于 tracer 不记录 Python 副作用,因此它们不会出现在 jaxpr 中。但是,副作用仍然在跟踪本身期间发生。

注意:Python print()函数不纯:文本输出是函数的副作用。因此,任何print()调用都只会发生在跟踪期间,并且不会出现在 jaxpr 中。

def log2_with_print(x):
  print("printed x:", x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2_with_print)(3.))
printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }

看到打印的x是一个Traced对象了吗?那是 JAX 内部机制在起作用。

Python 代码至少运行一次这一事实严格来说是实现细节,因此不应依赖它。但是,了解它很有用,因为您可以在调试时使用它来打印计算的中间值。

需要理解的一件关键事情是,jaxpr 会捕获在给定参数上执行的函数。例如,如果我们有一个 Python 条件语句,则 jaxpr 仅了解我们采取的分支。

def log2_if_rank_2(x):
  if x.ndim == 2:
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2
  else:
    return x

print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))
{ lambda ; a:i32[3]. let  in (a,) }

JIT 编译函数#

如前所述,JAX 允许使用相同的代码在 CPU/GPU/TPU 上执行操作。让我们来看一个计算缩放指数线性单元SELU)的示例,这是一种深度学习中常用的操作。

import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()
4.31 ms ± 76.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

上面的代码一次将一个操作发送到加速器。这限制了 XLA 编译器优化我们的函数的能力。

自然地,我们想要做的是尽可能多地提供代码给 XLA 编译器,以便它能够充分优化它。为此,JAX 提供了jax.jit()变换,它将 JIT 编译 JAX 兼容函数。下面的示例显示了如何使用 JIT 加速之前的函数。

selu_jit = jax.jit(selu)

# Pre-compile the function before timing...
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()
1.61 ms ± 3.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

以下是刚刚发生的事情

  1. 我们将selu_jit定义为selu的编译版本。

  2. 我们对x调用了selu_jit一次。这是 JAX 进行跟踪的地方——毕竟它需要一些输入来包装在 tracer 中。然后使用 XLA 将 jaxpr 编译成针对您的 GPU 或 TPU 优化的非常高效的代码。最后,执行编译后的代码以满足调用。随后对selu_jit的调用将直接使用编译后的代码,完全跳过 Python 实现。(如果我们没有单独包含预热调用,一切仍然可以工作,但编译时间将包含在基准测试中。它仍然会更快,因为我们在基准测试中运行了许多循环,但这不是一个公平的比较)。

  3. 我们对编译版本的执行速度进行了计时。(注意block_until_ready()的使用,这是由于 JAX 的异步调度所必需的)。

为什么不能只 JIT 所有内容?#

在完成上述示例后,您可能想知道我们是否应该简单地将jax.jit()应用于每个函数。为了理解为什么不是这样,以及何时应该/不应该应用jit,让我们首先检查一些 JIT 不起作用的情况。

# Condition on value of x.

def f(x):
  if x > 0:
    return x
  else:
    return 2 * x

jax.jit(f)(10)  # Raises an error
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_1268/2956679937.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
# While loop conditioned on x and n.

def g(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

jax.jit(g)(10, 20)  # Raises an error
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g at /tmp/ipykernel_1268/722961019.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

这两种情况下出现的问题是我们试图使用运行时值来控制程序的跟踪时间流。JIT 内部的跟踪值,如这里的xn,只能通过它们的静态属性来影响控制流:例如shapedtype,而不是通过它们的值。有关 Python 控制流与 JAX 之间交互的更多详细信息,请参阅🔪 JAX - The Sharp Bits 🔪: 控制流

解决此问题的一种方法是重写代码以避免对值进行条件判断。另一种方法是使用特殊的控制流运算符,如jax.lax.cond()。但是,有时这不可能或不切实际。在这种情况下,您可以考虑仅 JIT 编译函数的一部分。例如,如果函数中最耗时的部分在循环内部,我们可以只 JIT 编译该内部部分(但请确保查看下一节关于缓存的内容,以避免自讨苦吃)。

# While loop conditioned on x and n with a jitted body.

@jax.jit
def loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted(x, n):
  i = 0
  while i < n:
    i = loop_body(i)
  return x + i

g_inner_jitted(10, 20)
Array(30, dtype=int32, weak_type=True)

将参数标记为静态#

如果我们确实需要 JIT 编译一个对输入值的条件进行判断的函数,我们可以告诉 JAX 为特定输入使用不太抽象的 tracer,方法是指定static_argnumsstatic_argnames。这样做会导致生成的 jaxpr 和编译后的工件依赖于传递的特定值,因此 JAX 必须为指定静态输入的每个新值重新编译函数。只有在函数保证看到有限的静态值集时,这才是好的策略。

f_jit_correct = jax.jit(f, static_argnums=0)
print(f_jit_correct(10))
10
g_jit_correct = jax.jit(g, static_argnames=['n'])
print(g_jit_correct(10, 20))
30

当使用jit作为装饰器时,指定此类参数的常见模式是使用 Python 的functools.partial()

from functools import partial

@partial(jax.jit, static_argnames=['n'])
def g_jit_decorated(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

print(g_jit_decorated(10, 20))
30

JIT 和缓存#

有了第一次 JIT 调用的编译开销,了解jax.jit()如何以及何时缓存以前的编译对于有效地使用它至关重要。

假设我们定义f = jax.jit(g)。当我们第一次调用f时,它将被编译,并且生成的 XLA 代码将被缓存。随后对f的调用将重用缓存的代码。这就是jax.jit弥补编译前期成本的方式。

如果我们指定static_argnums,则仅当参数的值与标记为静态的参数的值相同,缓存的代码才会被使用。如果其中任何一个发生更改,则会重新编译。如果值很多,那么您的程序可能花费更多时间编译而不是逐个执行操作。

避免在循环或其他 Python 范围内定义的临时函数上调用jax.jit()。对于大多数情况,JAX 将能够在后续对jax.jit()的调用中使用编译后的缓存函数。但是,由于缓存依赖于函数的哈希值,因此当重新定义等效函数时,就会出现问题。这将导致每次循环中不必要的编译。

from functools import partial

def unjitted_loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted_partial(x, n):
  i = 0
  while i < n:
    # Don't do this! each time the partial returns
    # a function with different hash
    i = jax.jit(partial(unjitted_loop_body))(i)
  return x + i

def g_inner_jitted_lambda(x, n):
  i = 0
  while i < n:
    # Don't do this!, lambda will also return
    # a function with a different hash
    i = jax.jit(lambda x: unjitted_loop_body(x))(i)
  return x + i

def g_inner_jitted_normal(x, n):
  i = 0
  while i < n:
    # this is OK, since JAX can find the
    # cached, compiled function
    i = jax.jit(unjitted_loop_body)(i)
  return x + i

print("jit called in a loop with partials:")
%timeit g_inner_jitted_partial(10, 20).block_until_ready()

print("jit called in a loop with lambdas:")
%timeit g_inner_jitted_lambda(10, 20).block_until_ready()

print("jit called in a loop with caching:")
%timeit g_inner_jitted_normal(10, 20).block_until_ready()
jit called in a loop with partials:
354 ms ± 5.08 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with lambdas:
356 ms ± 5.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with caching:
5.21 ms ± 9.25 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)