快速入门#
JAX 是一个用于面向数组的数值计算的库(类似于 NumPy),它具有自动微分和 JIT 编译功能,可以实现高性能机器学习研究.
本文档简要概述了 JAX 的基本功能,以便您快速入门 JAX
JAX 提供了一个统一的 NumPy 式界面,用于在本地或分布式设置中,在 CPU、GPU 或 TPU 上运行的计算。
JAX 通过 Open XLA 提供了内置的即时 (JIT) 编译功能,Open XLA 是一个开源的机器学习编译器生态系统。
JAX 函数支持通过其自动微分转换来高效评估梯度。
JAX 函数可以自动矢量化,以便高效地将它们映射到表示输入批次的数组上。
安装#
JAX 可以从 Python 包索引 直接安装到 Linux、Windows 和 macOS 上的 CPU
pip install jax
或者,对于 NVIDIA GPU
pip install -U "jax[cuda12]"
有关更详细的特定于平台的安装信息,请查看 安装.
JAX 作为 NumPy#
大多数 JAX 使用都通过熟悉的 jax.numpy
API 进行,该 API 通常在 jnp
别名下导入
import jax.numpy as jnp
通过此导入,您可以立即以类似于典型 NumPy 程序的方式使用 JAX,包括使用 NumPy 式数组创建函数、Python 函数和运算符,以及数组属性和方法
def selu(x, alpha=1.67, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = jnp.arange(5.0)
print(selu(x))
[0. 1.05 2.1 3.1499999 4.2 ]
当您开始深入研究时,您会发现 JAX 数组和 NumPy 数组之间存在一些差异;这些将在 🔪 JAX - The Sharp Bits 🔪 中探讨。
使用 jax.jit()
# 的即时编译
JAX 在 GPU 或 TPU 上透明运行(如果您没有,则回退到 CPU)。但是,在上面的示例中,JAX 正在一次一个操作地将内核调度到芯片上。如果我们有一系列操作,我们可以使用 jax.jit()
函数使用 XLA 将这系列操作一起编译。
我们可以使用 IPython 的 %timeit
来快速基准测试我们的 selu
函数,使用 block_until_ready()
来解释 JAX 的动态调度(参见 Asynchronous dispatch)
from jax import random
key = random.key(1701)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()
4.09 ms ± 18.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
(注意我们使用了 jax.random
来生成一些随机数;有关如何在 JAX 中生成随机数的详细信息,请查看 Pseudorandom numbers)。
我们可以使用 jax.jit()
变换来加速此函数的执行,该变换将在首次调用 selu
时进行即时编译,并将在之后进行缓存。
from jax import jit
selu_jit = jit(selu)
_ = selu_jit(x) # compiles on first call
%timeit selu_jit(x).block_until_ready()
1.33 ms ± 4.41 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
上面的计时代表在 CPU 上执行,但相同的代码可以在 GPU 或 TPU 上运行,通常可以获得更快的速度。
有关 JAX 中 JIT 编译的更多信息,请查看 Just-in-time compilation。
使用 jax.grad()
# 求导数
除了通过 JIT 编译来转换函数之外,JAX 还提供其他转换。其中一项转换是 jax.grad()
,它执行 自动微分 (autodiff)
from jax import grad
def sum_logistic(x):
return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
[0.25 0.19661197 0.10499357]
让我们用有限差分来验证我们的结果是否正确。
def first_finite_differences(f, x, eps=1E-3):
return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
for v in jnp.eye(len(x))])
print(first_finite_differences(sum_logistic, x_small))
[0.24998187 0.1965761 0.10502338]
grad()
和 jit()
变换可以组合,并且可以任意混合。在上面的示例中,我们对 sum_logistic
进行了即时编译,然后求了它的导数。我们可以更进一步
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
-0.0353256
除了标量值函数之外,jax.jacobian()
变换可用于计算向量值函数的完整雅可比矩阵
from jax import jacobian
print(jacobian(jnp.exp)(x_small))
[[1. 0. 0. ]
[0. 2.7182817 0. ]
[0. 0. 7.389056 ]]
对于更高级的 autodiff 操作,您可以使用 jax.vjp()
进行反向模式向量-雅可比乘积,以及 jax.jvp()
和 jax.linearize()
进行前向模式雅可比-向量乘积。这两种模式可以任意相互组合,也可以与其他 JAX 变换组合。例如,jax.jvp()
和 jax.vjp()
用于定义前向模式 jax.jacfwd()
和反向模式 jax.jacrev()
,分别用于在前向和反向模式下计算雅可比矩阵。以下是一种组合它们以创建有效计算完整 Hessian 矩阵的函数的方法
from jax import jacfwd, jacrev
def hessian(fun):
return jit(jacfwd(jacrev(fun)))
print(hessian(sum_logistic)(x_small))
[[-0. -0. -0. ]
[-0. -0.09085776 -0. ]
[-0. -0. -0.07996249]]
这种组合在实践中会生成有效的代码;这几乎是 JAX 的内置 jax.hessian()
函数的实现方式。
有关 JAX 中自动微分的更多信息,请查看 Automatic differentiation。
使用 jax.vmap()
# 的自动向量化
另一个有用的转换是 vmap()
,即向量化映射。它具有沿数组轴映射函数的熟悉语义,但它不是显式循环遍历函数调用,而是将函数转换为本机向量化版本,以提高性能。当与 jit()
组合时,它可以与手动重写函数以在额外的批处理维度上操作一样有效。
我们将使用一个简单的示例,并使用 vmap()
将矩阵-向量乘积提升为矩阵-矩阵乘积。虽然在这种特殊情况下,手动操作很容易,但相同的技术可以应用于更复杂的函数。
key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))
def apply_matrix(x):
return jnp.dot(mat, x)
apply_matrix
函数将向量映射到向量,但我们可能希望将其逐行应用于矩阵。我们可以通过在 Python 中循环遍历批处理维度来做到这一点,但这通常会导致性能低下。
def naively_batched_apply_matrix(v_batched):
return jnp.stack([apply_matrix(v) for v in v_batched])
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched
2.93 ms ± 44.4 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
熟悉 jnp.dot
函数的程序员可能会意识到 apply_matrix
可以重写为避免显式循环,使用 jnp.dot
的内置批处理语义
import numpy as np
@jit
def batched_apply_matrix(batched_x):
return jnp.dot(batched_x, mat.T)
np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
Manually batched
86.4 μs ± 3.86 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
但是,随着函数变得越来越复杂,这种手动批处理变得越来越困难,也更容易出错。vmap()
变换旨在自动将函数转换为批处理感知版本
from jax import vmap
@jit
def vmap_batched_apply_matrix(batched_x):
return vmap(apply_matrix)(batched_x)
np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Auto-vectorized with vmap
91.2 μs ± 2.85 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
正如您所预期的那样,vmap()
可以任意与 jit()
、grad()
和任何其他 JAX 变换组合。
有关 JAX 中自动向量化的更多信息,请查看 Automatic vectorization。
这只是 JAX 可以做的事情的简要概述。我们非常期待看到您用它做什么!