自动微分#
在本节中,您将了解 JAX 中自动微分(自动微分)的基本应用。JAX 具有非常通用的自动微分系统。计算梯度是现代机器学习方法的关键部分,本教程将引导您完成一些自动微分的入门主题,例如
确保也查看 高级自动微分 教程以了解更高级的主题。
虽然了解自动微分“幕后”的工作原理对于在大多数情况下使用 JAX 并不重要,但鼓励您查看这个非常易懂的 视频 以更深入地了解正在发生的事情。
1. 使用 jax.grad
获取梯度#
在 JAX 中,可以使用 jax.grad()
转换来微分标量值函数
import jax
import jax.numpy as jnp
from jax import grad
grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
0.070650816
jax.grad()
接受一个函数并返回一个函数。如果你有一个 Python 函数 f
,它用来计算数学函数 \(f\),那么 jax.grad(f)
就是一个 Python 函数,它用来计算数学函数 \(\nabla f\)。这意味着 grad(f)(x)
代表了值 \(\nabla f(x)\).
由于 jax.grad()
是作用于函数的,你可以把它应用到自己的输出上,这样就可以进行任意次数的微分。
print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))
-0.13621868
0.25265405
JAX 的自动微分使得计算高阶导数变得很容易,因为计算导数的函数本身也是可微的。因此,高阶导数就像堆叠变换一样简单。这可以在单变量情况下说明。
\(f(x) = x^3 + 2x^2 - 3x + 1\) 的导数可以计算为
f = lambda x: x**3 + 2*x**2 - 3*x + 1
dfdx = jax.grad(f)
\(f\) 的高阶导数为
在 JAX 中计算这些导数就像将 jax.grad()
函数链接起来一样简单。
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)
在 \(x=1\) 处对以上表达式进行求值会得到
使用 JAX
print(dfdx(1.))
print(d2fdx(1.))
print(d3fdx(1.))
print(d4fdx(1.))
4.0
10.0
6.0
0.0
2. 在线性逻辑回归中计算梯度#
下一个例子展示了如何在线性逻辑回归模型中使用 jax.grad()
来计算梯度。首先,设置。
key = jax.random.key(0)
def sigmoid(x):
return 0.5 * (jnp.tanh(x / 2) + 1)
# Outputs probability of a label being true.
def predict(W, b, inputs):
return sigmoid(jnp.dot(inputs, W) + b)
# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12, 0.77],
[0.88, -1.08, 0.15],
[0.52, 0.06, -1.30],
[0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])
# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
preds = predict(W, b, inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
# Initialize random model coefficients
key, W_key, b_key = jax.random.split(key, 3)
W = jax.random.normal(W_key, (3,))
b = jax.random.normal(b_key, ())
使用 jax.grad()
函数及其 argnums
参数来对函数相对于位置参数进行微分。
# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print(f'{W_grad=}')
# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print(f'{W_grad=}')
# But you can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print(f'{b_grad=}')
# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print(f'{W_grad=}')
print(f'{b_grad=}')
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
b_grad=Array(-0.29227245, dtype=float32)
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
b_grad=Array(-0.29227245, dtype=float32)
jax.grad()
API 与 Spivak 的经典著作 Calculus on Manifolds (1965) 中的优秀符号完全对应,该符号也用于 Sussman 和 Wisdom 的 Structure and Interpretation of Classical Mechanics (2015) 和他们的 Functional Differential Geometry (2013)。这两本书都是开放获取的。特别是查看 Functional Differential Geometry 的“前言”部分,了解对这种符号的辩护。
本质上,当使用 argnums
参数时,如果 f
是用来计算数学函数 \(f\) 的 Python 函数,那么 Python 表达式 jax.grad(f, i)
计算结果是用来计算 \(\partial_i f\) 的 Python 函数。
3. 相对于嵌套列表、元组和字典进行微分#
由于 JAX 的 PyTree 抽象(参见 Working with pytrees),相对于标准 Python 容器进行微分就可以正常工作,因此可以按照你喜欢的方式使用元组、列表和字典(以及任意嵌套)。
继续上一个例子
def loss2(params_dict):
preds = predict(params_dict['W'], params_dict['b'], inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
print(grad(loss2)({'W': W, 'b': b}))
{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)}
你可以创建 Custom pytree nodes 来处理不仅 jax.grad()
而且其他 JAX 变换(jax.jit()
、jax.vmap()
等等)。
4. 使用 jax.value_and_grad
来计算函数及其梯度#
另一个方便的函数是 jax.value_and_grad()
,它可以有效地在一步内计算函数的值及其梯度的值。
继续上一个例子
loss_value, Wb_grad = jax.value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 3.0519385
loss value 3.0519385
5. 与数值差值进行比较#
导数的妙处在于可以用有限差分方法直观地进行验证。
继续上一个例子
# Set a step size for finite differences calculations
eps = 1e-4
# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))
# Check W_grad with finite differences in a random direction
key, subkey = jax.random.split(key)
vec = jax.random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245
W_dirderiv_numerical -0.2002716
W_dirderiv_autodiff -0.19909117
JAX 提供了一个简单的便捷函数,它本质上做了相同的事情,但可以检查你喜欢的任意阶导数。
from jax.test_util import check_grads
check_grads(loss, (W, b), order=2) # check up to 2nd order derivatives
下一步#
高级自动微分 教程提供了关于本节内容是如何在 JAX 后端实现的更高级和更详细的解释。一些功能,例如 JAX 可转换 Python 函数的自定义导数规则,依赖于对高级自动微分的理解,所以如果你感兴趣,请查看 高级自动微分 教程中的这一节。