快速入门#

JAX 是一个用于面向数组的数值计算的库(à la NumPy),具有自动微分和 JIT 编译功能,以支持高性能机器学习研究.

本文档快速概述了 JAX 的基本功能,以便您可以快速开始使用 JAX

  • JAX 为在本地或分布式环境中,在 CPU、GPU 或 TPU 上运行的计算提供了一个统一的类似 NumPy 的接口。

  • JAX 具有通过 Open XLA(一个开源机器学习编译器生态系统)构建的内置即时 (JIT) 编译功能。

  • JAX 函数通过其自动微分转换支持梯度的有效评估。

  • JAX 函数可以自动向量化,以便有效地将它们映射到表示输入批次的数组上。

安装#

JAX 可以直接从 Python 包索引 在 Linux、Windows 和 macOS 上为 CPU 安装

pip install jax

或者,对于 NVIDIA GPU

pip install -U "jax[cuda12]"

有关更详细的平台特定安装信息,请查看安装

JAX 作为 NumPy#

大多数 JAX 用法是通过熟悉的 jax.numpy API,它通常以 jnp 别名导入

import jax.numpy as jnp

通过此导入,您可以立即以类似于典型 NumPy 程序的方式使用 JAX,包括使用 NumPy 风格的数组创建函数、Python 函数和运算符以及数组属性和方法

def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)
print(selu(x))
[0.        1.05      2.1       3.1499999 4.2      ]

一旦您开始深入研究,您会发现 JAX 数组和 NumPy 数组之间存在一些差异;这些差异在 🔪 JAX - 尖锐部分 🔪 中进行了探讨。

使用 jax.jit() 的即时编译#

JAX 在 GPU 或 TPU 上透明运行(如果您没有,则回退到 CPU)。但是,在上面的示例中,JAX 正在一次向芯片调度一个操作的内核。如果我们有一系列操作,我们可以使用 jax.jit() 函数将这一系列操作一起使用 XLA 编译。

我们可以使用 IPython 的 %timeit 快速基准测试我们的 selu 函数,使用 block_until_ready() 来解释 JAX 的动态调度(请参阅异步调度

from jax import random

key = random.key(1701)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()
751 μs ± 24.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

(请注意,我们使用了 jax.random 来生成一些随机数;有关如何在 JAX 中生成随机数的详细信息,请查看伪随机数)。

我们可以使用 jax.jit() 转换来加速此函数的执行,这将在第一次调用 selu 时进行 jit 编译,并在之后缓存。

from jax import jit

selu_jit = jit(selu)
_ = selu_jit(x)  # compiles on first call
%timeit selu_jit(x).block_until_ready()
245 μs ± 5.33 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

上面的计时表示在 CPU 上的执行,但相同的代码可以在 GPU 或 TPU 上运行,通常是为了更大的加速。

有关 JAX 中 JIT 编译的更多信息,请查看即时编译

使用 jax.grad() 求导#

除了通过 JIT 编译转换函数外,JAX 还提供了其他转换。其中一种转换是 jax.grad(),它执行自动微分 (autodiff)

from jax import grad

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
[0.25       0.19661197 0.10499357]

让我们用有限差分验证我们的结果是正确的。

def first_finite_differences(f, x, eps=1E-3):
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])

print(first_finite_differences(sum_logistic, x_small))
[0.24998187 0.1964569  0.10502338]

grad()jit() 转换可以组合并可以任意混合。在上面的示例中,我们 jitted 了 sum_logistic,然后求了它的导数。我们可以更进一步

print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
-0.0353256

除了标量值函数之外,jax.jacobian() 转换可用于计算向量值函数的完整 Jacobian 矩阵

from jax import jacobian
print(jacobian(jnp.exp)(x_small))
[[1.        0.        0.       ]
 [0.        2.7182817 0.       ]
 [0.        0.        7.389056 ]]

对于更高级的自动微分操作,您可以使用 jax.vjp() 进行反向模式向量-雅可比积,以及 jax.jvp()jax.linearize() 用于前向模式雅可比-向量积。这两者可以彼此任意组合,并且可以与其他 JAX 转换组合。例如,jax.jvp()jax.vjp() 用于定义前向模式 jax.jacfwd() 和反向模式 jax.jacrev(),分别用于在前向和反向模式下计算 Jacobian 矩阵。这是一种组合它们的方法,可以创建一个有效计算完整 Hessian 矩阵的函数

from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
print(hessian(sum_logistic)(x_small))
[[-0.         -0.         -0.        ]
 [-0.         -0.09085776 -0.        ]
 [-0.         -0.         -0.07996249]]

这种组合在实践中产生高效的代码;这或多或少是 JAX 的内置 jax.hessian() 函数的实现方式。

有关 JAX 中自动微分的更多信息,请查看自动微分

使用 jax.vmap() 自动向量化#

另一个有用的转换是 vmap(),向量化映射。它具有沿数组轴映射函数的熟悉语义,但不是显式循环函数调用,它将函数转换为本机向量化版本,以获得更好的性能。当与 jit() 组合使用时,它可以像手动重写函数以在额外的批次维度上操作一样高效。

我们将使用一个简单的示例,并使用 vmap() 将矩阵-向量积提升为矩阵-矩阵积。虽然在这种特定情况下很容易手动完成,但相同的技术可以应用于更复杂的函数。

key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))

def apply_matrix(x):
  return jnp.dot(mat, x)

apply_matrix 函数将向量映射到向量,但我们可能希望在矩阵中按行应用它。我们可以通过在 Python 中循环批次维度来做到这一点,但这通常会导致性能不佳。

def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched
443 μs ± 7.53 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

熟悉 jnp.dot 函数的程序员可能会认识到,可以重写 apply_matrix 以避免显式循环,使用 jnp.dot 的内置批处理语义

import numpy as np

@jit
def batched_apply_matrix(batched_x):
  return jnp.dot(batched_x, mat.T)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
Manually batched
12.1 μs ± 84.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

然而,随着函数变得越来越复杂,这种手动批处理变得更加困难且容易出错。vmap() 转换旨在自动将函数转换为批处理感知版本

from jax import vmap

@jit
def vmap_batched_apply_matrix(batched_x):
  return vmap(apply_matrix)(batched_x)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Auto-vectorized with vmap
16.2 μs ± 92.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

正如您所期望的那样,vmap() 可以与 jit()grad() 和任何其他 JAX 转换任意组合。

有关 JAX 中自动向量化的更多信息,请查看自动向量化

这只是 JAX 功能的冰山一角。我们非常期待看到您如何使用它!