如何用 JAX 思考#

Open in Colab Open in Kaggle

JAX 提供了一个简单而强大的 API,用于编写加速的数值代码,但在 JAX 中高效工作有时需要额外的考虑。本文档旨在帮助建立对 JAX 如何运作的由浅入深的理解,以便您可以更有效地使用它。

JAX 与 NumPy#

核心概念

  • JAX 为了方便起见,提供了一个受 NumPy 启发的接口。

  • 通过鸭子类型,JAX 数组通常可以用作 NumPy 数组的直接替换。

  • 与 NumPy 数组不同,JAX 数组始终是不可变的。

NumPy 提供了一个众所周知的、强大的 API,用于处理数值数据。为了方便起见,JAX 提供了 jax.numpy,它与 numpy API 非常相似,并为 JAX 提供了简单的入口。几乎任何可以用 numpy 完成的操作都可以用 jax.numpy 完成

import matplotlib.pyplot as plt
import numpy as np

x_np = np.linspace(0, 10, 1000)
y_np = 2 * np.sin(x_np) * np.cos(x_np)
plt.plot(x_np, y_np);
../_images/b2db475a8afa1d2e364a801f61f7b347b75a355e9da0be2f015a2d1aefdea45c.png
import jax.numpy as jnp

x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);
../_images/487cfe9c47318bd2e5849cf09dc8048af87a3364e9f0e0e524de8e950911888e.png

代码块除了将 np 替换为 jnp 之外是相同的,并且结果也相同。正如我们所看到的,JAX 数组通常可以直接代替 NumPy 数组用于绘图等操作。

数组本身以不同的 Python 类型实现

type(x_np)
numpy.ndarray
type(x_jnp)
jaxlib.xla_extension.ArrayImpl

Python 的 鸭子类型 允许 JAX 数组和 NumPy 数组在许多地方可以互换使用。

然而,JAX 数组和 NumPy 数组之间有一个重要的区别:JAX 数组是不可变的,这意味着一旦创建,它们的内容就不能更改。

以下是在 NumPy 中修改数组的示例

# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x)
[10  1  2  3  4  5  6  7  8  9]

在 JAX 中,等效操作会导致错误,因为 JAX 数组是不可变的

%xmode minimal
Exception reporting mode: Minimal
# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10
TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://jax.ac.cn/en/latest/_autosummary/jax.numpy.ndarray.at.html

对于更新单个元素,JAX 提供了一个 索引更新语法,它会返回更新后的副本

y = x.at[0].set(10)
print(x)
print(y)
[0 1 2 3 4 5 6 7 8 9]
[10  1  2  3  4  5  6  7  8  9]

NumPy、lax & XLA:JAX API 分层#

核心概念

  • jax.numpy 是一个高级包装器,提供了一个熟悉的接口。

  • jax.lax 是一个更底层的 API,它更严格,通常更强大。

  • 所有 JAX 操作都是根据 XLA(加速线性代数编译器)中的操作实现的。

如果您查看 jax.numpy 的源代码,您会看到所有操作最终都以 jax.lax 中定义的函数来表示。您可以将 jax.lax 看作是用于处理多维数组的更严格但通常更强大的 API。

例如,虽然 jax.numpy 会隐式提升参数以允许混合数据类型之间的操作,但 jax.lax 不会

import jax.numpy as jnp
jnp.add(1, 1.0)  # jax.numpy API implicitly promotes mixed types.
Array(2., dtype=float32, weak_type=True)
from jax import lax
lax.add(1, 1.0)  # jax.lax API requires explicit type promotion.
TypeError: lax.add requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.add is a similar function that does automatic type promotion on inputs).

如果直接使用 jax.lax,则在这种情况下必须显式执行类型提升

lax.add(jnp.float32(1), 1.0)
Array(2., dtype=float32)

除了这种严格性之外,jax.lax 还为一些比 NumPy 支持的更通用的操作提供了高效的 API。

例如,考虑一个一维卷积,它可以用 NumPy 这样表示

x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

在底层,此 NumPy 操作被转换为由 lax.conv_general_dilated 实现的更通用的卷积

from jax import lax
result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float),  # note: explicit promotion
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # equivalent of padding='full' in NumPy
result[0, 0]
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

这是一个批量卷积操作,旨在对深度神经网络中常用的卷积类型高效执行。它需要更多的样板代码,但比 NumPy 提供的卷积更灵活和可扩展(有关 JAX 卷积的更多详细信息,请参阅 JAX 中的卷积)。

本质上,所有 jax.lax 操作都是 XLA 中操作的 Python 包装器;例如,这里的卷积实现由 XLA:ConvWithGeneralPadding 提供。每个 JAX 操作最终都以这些基本的 XLA 操作来表示,这使得即时 (JIT) 编译成为可能。

是否使用 JIT#

核心概念

  • 默认情况下,JAX 一次按顺序执行一个操作。

  • 使用即时 (JIT) 编译装饰器,可以一起优化一系列操作并一次运行。

  • 并非所有 JAX 代码都可以进行 JIT 编译,因为它要求数组形状在编译时是静态的且已知的。

所有 JAX 操作都以 XLA 来表示的事实使得 JAX 可以使用 XLA 编译器非常高效地执行代码块。

例如,考虑这个函数,它对一个二维矩阵的行进行归一化,用 jax.numpy 操作表示

import jax.numpy as jnp

def norm(X):
  X = X - X.mean(0)
  return X / X.std(0)

可以使用 jax.jit 转换创建该函数的即时编译版本

from jax import jit
norm_compiled = jit(norm)

该函数返回与原始函数相同的结果,直至标准的浮点精度

np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)
True

但是由于编译(包括操作的融合、避免分配临时数组以及其他许多技巧),在 JIT 编译的情况下,执行时间可以快几个数量级(请注意使用 block_until_ready() 来解决 JAX 的 异步调度

%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
358 μs ± 5.54 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
257 μs ± 2.13 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

也就是说,jax.jit 确实有局限性:特别是,它要求所有数组都具有静态形状。这意味着某些 JAX 操作与 JIT 编译不兼容。

例如,此操作可以在逐个操作模式下执行

def get_negatives(x):
  return x[x < 0]

x = jnp.array(np.random.randn(10))
get_negatives(x)
Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)

但是,如果您尝试在 jit 模式下执行它,则会返回错误

jit(get_negatives)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[10])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

这是因为该函数生成一个其形状在编译时未知的数组:输出的大小取决于输入数组的值,因此它与 JIT 不兼容。

JIT 机制:跟踪和静态变量#

核心概念

  • JIT 和其他 JAX 转换通过跟踪一个函数来确定其对特定形状和类型的输入的影响。

  • 您不想被跟踪的变量可以标记为静态

为了有效地使用 jax.jit,了解它的工作原理很有用。让我们在 JIT 编译的函数中放置几个 print() 语句,然后调用该函数

@jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.dot(x + 1, y + 1)
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)
Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace>
Array([0.25773212, 5.3623195 , 5.403243  ], dtype=float32)

请注意,print 语句会执行,但它不是打印我们传递给函数的数据,而是打印代替它们的tracer对象。

这些 tracer 对象是 jax.jit 用来提取函数指定的运算序列的对象。基本 tracer 是代替数组的 形状dtype 的对象,但与值无关。然后,可以在 XLA 中将此记录的计算序列高效地应用于具有相同形状和 dtype 的新输入,而无需重新执行 Python 代码。

当我们再次在匹配的输入上调用编译后的函数时,不需要重新编译,并且不会打印任何内容,因为结果是在编译后的 XLA 中而不是在 Python 中计算的

x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)
Array([1.4344584, 4.3004413, 7.9897013], dtype=float32)

提取的操作序列在 JAX 表达式(简称 jaxpr)中编码。您可以使用 jax.make_jaxpr 转换来查看 jaxpr

from jax import make_jaxpr

def f(x, y):
  return jnp.dot(x + 1, y + 1)

make_jaxpr(f)(x, y)
{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0
    d:f32[4] = add b 1.0
    e:f32[3] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }

请注意这一点的一个后果:由于 JIT 编译是在没有数组内容信息的情况下完成的,因此函数中的控制流语句不能依赖于跟踪值。例如,这会失败

@jit
def f(x, neg):
  return -x if neg else x

f(1, True)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_2619/2422663986.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

如果有您不想跟踪的变量,可以将它们标记为静态以进行 JIT 编译

from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
  return -x if neg else x

f(1, True)
Array(-1, dtype=int32, weak_type=True)

请注意,使用不同的静态参数调用 JIT 编译的函数会导致重新编译,因此该函数仍然按预期工作

f(1, False)
Array(1, dtype=int32, weak_type=True)

了解哪些值和操作将是静态的,哪些将被跟踪是有效使用 jax.jit 的关键部分。

静态操作与跟踪操作#

核心概念

  • 正如值可以是静态的或被跟踪的一样,操作也可以是静态的或被跟踪的。

  • 静态操作在 Python 中的编译时进行求值;跟踪操作在 XLA 中的运行时进行编译和求值。

  • 对于您想要静态执行的操作,请使用 numpy;对于您想要跟踪的操作,请使用 jax.numpy

静态值和跟踪值之间的这种区别使得考虑如何保持静态值的静态性变得很重要。考虑这个函数

import jax.numpy as jnp
from jax import jit

@jit
def f(x):
  return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function f at /tmp/ipykernel_2619/1983583872.py:4 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /tmp/ipykernel_2619/1983583872.py:6 (f)

这会失败,并显示一条错误,指出找到了一个 tracer,而不是整数类型的具体值的一维序列。让我们向该函数添加一些 print 语句,以了解为什么会发生这种情况

@jit
def f(x):
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # comment this out to avoid the error:
  # return x.reshape(jnp.array(x.shape).prod())

f(x)
x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace>

请注意,虽然 x 被跟踪,但 x.shape 是一个静态值。但是,当我们在该静态值上使用 jnp.arrayjnp.prod 时,它会变成一个被跟踪的值,此时它不能在像 reshape() 这样的需要静态输入的函数中使用(回想一下:数组形状必须是静态的)。

一个有用的模式是对于应该静态执行(即在编译时完成)的操作,使用 numpy,而对于应该跟踪(即在运行时编译和执行)的操作,使用 jax.numpy。对于这个函数,它可能看起来像这样

from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
  return x.reshape((np.prod(x.shape),))

f(x)
Array([1., 1., 1., 1., 1., 1.], dtype=float32)

出于这个原因,JAX 程序中的一个标准约定是 import numpy as npimport jax.numpy as jnp,以便两个接口都可用于更精细地控制操作是以静态方式(使用 numpy,在编译时一次)还是以跟踪方式(使用 jax.numpy,在运行时优化)执行。