自动微分#
在本节中,您将了解 JAX 中自动微分(autodiff)的基本应用。JAX 具有相当通用的自动微分系统。计算梯度是现代机器学习方法的关键部分,本教程将引导您了解一些入门级的自动微分主题,例如
请务必查看高级自动微分教程,了解更多高级主题。
虽然理解自动微分“底层”的工作原理对于在大多数情况下使用 JAX 来说并非至关重要,但我们还是鼓励您观看这个非常容易理解的视频,以便更深入地了解其工作原理。
1. 使用 jax.grad
求梯度#
在 JAX 中,您可以使用 jax.grad()
变换对标量值函数进行微分
import jax
import jax.numpy as jnp
from jax import grad
grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
0.070650816
jax.grad()
接受一个函数并返回一个函数。如果您有一个 Python 函数 f
来计算数学函数 \(f\),那么 jax.grad(f)
就是一个 Python 函数,用于计算数学函数 \(\nabla f\)。这意味着 grad(f)(x)
表示值 \(\nabla f(x)\)。
由于 jax.grad()
对函数进行操作,因此您可以将其应用于自身的输出,从而进行任意次数的微分
print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))
-0.13621868
0.25265405
JAX 的自动微分功能可以轻松计算高阶导数,因为计算导数的函数本身是可微的。因此,高阶导数与堆叠变换一样简单。这可以在单变量情况下进行说明
函数 \(f(x) = x^3 + 2x^2 - 3x + 1\) 的导数可以计算为
f = lambda x: x**3 + 2*x**2 - 3*x + 1
dfdx = jax.grad(f)
\(f\) 的高阶导数为
在 JAX 中计算这些导数中的任何一个都像链式调用 jax.grad()
函数一样简单
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)
在 \(x=1\) 处评估上述表达式将得到
使用 JAX:
print(dfdx(1.))
print(d2fdx(1.))
print(d3fdx(1.))
print(d4fdx(1.))
4.0
10.0
6.0
0.0
2. 计算线性逻辑回归中的梯度#
下一个示例演示如何在线性逻辑回归模型中使用 jax.grad()
计算梯度。首先,进行设置:
key = jax.random.key(0)
def sigmoid(x):
return 0.5 * (jnp.tanh(x / 2) + 1)
# Outputs probability of a label being true.
def predict(W, b, inputs):
return sigmoid(jnp.dot(inputs, W) + b)
# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12, 0.77],
[0.88, -1.08, 0.15],
[0.52, 0.06, -1.30],
[0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])
# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
preds = predict(W, b, inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
# Initialize random model coefficients
key, W_key, b_key = jax.random.split(key, 3)
W = jax.random.normal(W_key, (3,))
b = jax.random.normal(b_key, ())
使用 jax.grad()
函数及其 argnums
参数,对函数关于位置参数进行微分。
# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print(f'{W_grad=}')
# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print(f'{W_grad=}')
# But you can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print(f'{b_grad=}')
# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print(f'{W_grad=}')
print(f'{b_grad=}')
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
b_grad=Array(-0.29227245, dtype=float32)
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
b_grad=Array(-0.29227245, dtype=float32)
jax.grad()
API 与 Spivak 的经典著作 流形上的微积分 (1965) 中的出色符号表示法直接对应,该符号表示法也用于 Sussman 和 Wisdom 的经典力学的结构与解释 (2015) 和他们的 函数微分几何 (2013)。这两本书都是开放获取的。有关此符号表示的辩护,请特别参阅函数微分几何的“序言”部分。
本质上,当使用 argnums
参数时,如果 f
是一个用于计算数学函数 \(f\) 的 Python 函数,那么 Python 表达式 jax.grad(f, i)
将计算一个用于计算 \(\partial_i f\) 的 Python 函数。
3. 对嵌套列表、元组和字典进行微分#
由于 JAX 的 PyTree 抽象(请参阅 使用 pytrees),对标准 Python 容器进行微分是直接有效的,因此您可以随意使用元组、列表和字典(以及任意嵌套)。
继续前面的例子:
def loss2(params_dict):
preds = predict(params_dict['W'], params_dict['b'], inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
print(grad(loss2)({'W': W, 'b': b}))
{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)}
您可以创建自定义 pytree 节点,不仅可以与jax.grad()
一起使用,还可以与其他 JAX 转换(jax.jit()
、jax.vmap()
等)一起使用。
4. 使用 jax.value_and_grad
评估函数及其梯度#
另一个方便的函数是 jax.value_and_grad()
,用于高效地一次性计算函数的值及其梯度值。
继续前面的示例:
loss_value, Wb_grad = jax.value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 3.0519385
loss value 3.0519385
5. 与数值差异进行比较#
关于导数的一大优点是,它们可以使用有限差分进行直接比较。
继续前面的示例:
# Set a step size for finite differences calculations
eps = 1e-4
# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))
# Check W_grad with finite differences in a random direction
key, subkey = jax.random.split(key)
vec = jax.random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245
W_dirderiv_numerical -0.2002716
W_dirderiv_autodiff -0.19909117
JAX 提供了一个简单的便捷函数,该函数本质上执行相同的操作,但会检查您喜欢的任何阶数的微分
from jax.test_util import check_grads
check_grads(loss, (W, b), order=2) # check up to 2nd order derivatives
下一步#
高级自动微分教程提供了关于本文档中涵盖的思想如何在 JAX 后端中实现的更高级和详细的解释。诸如JAX 可转换 Python 函数的自定义导数规则之类的某些功能依赖于对高级自动微分的理解,因此如果您有兴趣,请务必查看高级自动微分教程中的该部分。