自动微分食谱#

Open in Colab Open in Kaggle

JAX 具有一个非常通用的自动微分系统。在本笔记本中,我们将介绍大量有用的自动微分概念,您可以根据自己的工作选择使用,从基础开始。

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

key = random.key(0)

梯度#

grad 开始#

您可以使用 grad 对函数进行微分

grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
0.070650816

grad 接受一个函数并返回一个函数。如果您有一个评估数学函数 \(f\) 的 Python 函数 f,则 grad(f) 是一个评估数学函数 \(\nabla f\) 的 Python 函数。这意味着 grad(f)(x) 表示值 \(\nabla f(x)\).

由于 grad 操作的是函数,您可以将其应用于自己的输出,以根据需要进行多次微分

print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))
-0.13621868
0.25265405

让我们看看如何在线性逻辑回归模型中使用 grad 计算梯度。首先,进行设置

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)

# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())

使用 grad 函数及其 argnums 参数,针对位置参数对函数进行微分。

# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)

# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)

# But we can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)

# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
b_grad -0.29227245
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
b_grad -0.29227245

grad API 与 Spivak 的经典著作 《微分流形上的微积分》(1965 年)中出色的符号表示直接对应,该符号表示也应用于 Sussman 和 Wisdom 的 《经典力学的结构和解释》(2015 年)和他们的 《函数微分几何》(2013 年)。这两本书都是开放获取的。特别是,请查看《函数微分几何》的“前言”部分,了解对该符号表示的辩护。

本质上,使用 argnums 参数时,如果 f 是用于评估数学函数 \(f\) 的 Python 函数,那么 Python 表达式 grad(f, i) 会评估为用于评估 \(\partial_i f\) 的 Python 函数。

针对嵌套列表、元组和字典进行微分#

针对标准 Python 容器进行微分就可以了,因此请根据需要使用元组、列表和字典(以及任意嵌套)。

def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

print(grad(loss2)({'W': W, 'b': b}))
{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)}

您可以 注册您自己的容器类型,使其不仅能够与 grad 协同工作,还能与所有 JAX 变换(jitvmap 等)协同工作。

使用 value_and_grad 评估函数及其梯度#

另一个便捷的函数是 value_and_grad,它可以高效地同时计算函数的值和梯度值。

from jax import value_and_grad
loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 3.0519385
loss value 3.0519385

针对数值差进行检查#

微分的优点在于,可以用有限差分对它们进行直接检查。

# Set a step size for finite differences calculations
eps = 1e-4

# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))

# Check W_grad with finite differences in a random direction
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245
W_dirderiv_numerical -0.2002716
W_dirderiv_autodiff -0.19909117

JAX 提供了一个简单的便捷函数,其功能与之基本相同,但可以检查您喜欢的任意阶次微分。

from jax.test_util import check_grads
check_grads(loss, (W, b), order=2)  # check up to 2nd order derivatives

使用 gradgrad 创建 Hessian-向量积#

我们可以用更高阶的 grad 来构建 Hessian-向量积函数。(稍后,我们将编写一个更有效的实现,它混合使用正向模式和反向模式,但此实现将使用纯反向模式。)

Hessian-向量积函数在 截断牛顿共轭梯度算法 中很有用,该算法可用于最小化平滑凸函数,或用于研究神经网络训练目标的曲率(例如 1234)。

对于一个标量值函数 \(f : \mathbb{R}^n \to \mathbb{R}\),如果它具有连续的二阶导数(因此 Hessian 矩阵是对称的),则该函数在点 \(x \in \mathbb{R}^n\) 处的 Hessian 写作 \(\partial^2 f(x)\)。然后,Hessian-向量积函数可以评估

\(\qquad v \mapsto \partial^2 f(x) \cdot v\)

对于任何 \(v \in \mathbb{R}^n\)

关键在于不要实例化完整的 Hessian 矩阵:如果 \(n\) 很大,例如在神经网络的上下文中可能达到数百万或数十亿,那么这可能无法存储。

幸运的是,grad 已经为我们提供了一种编写高效 Hessian-向量积函数的方法。我们只需要使用以下恒等式即可

\(\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)\),

其中 \(g(x) = \partial f(x) \cdot v\) 是一个新的标量值函数,它将 \(f\)\(x\) 处的梯度与向量 \(v\) 进行点积。请注意,我们只对向量值参数的标量值函数进行微分,而这正是我们知道 grad 效率很高的领域。

在 JAX 代码中,我们可以直接编写以下代码

def hvp(f, x, v):
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

此示例表明,您可以随意使用词法闭包,JAX 永远不会受到干扰或混乱。

我们将在接下来的几个单元格中检查此实现,届时我们将看到如何计算密集的 Hessian 矩阵。我们还将编写一个更好的版本,它使用正向模式和反向模式。

使用 jacfwdjacrev 计算雅可比行列式和 Hessian#

您可以使用 jacfwdjacrev 函数计算完整的雅可比行列式矩阵。

from jax import jacfwd, jacrev

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)

J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)
jacfwd result, with shape (4, 3)
[[ 0.05981758  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188288  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]
jacrev result, with shape (4, 3)
[[ 0.05981757  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]

这两个函数计算相同的值(在机器数值上有所不同),但实现方式不同:jacfwd 使用正向模式自动微分,这对于“高”雅可比行列式矩阵(输出比输入多)更有效,而 jacrev 使用反向模式,这对于“宽”雅可比行列式矩阵(输入比输出多)更有效。对于接近方阵的矩阵,jacfwd 可能比 jacrev 更有效。

您也可以将 jacfwdjacrev 与容器类型一起使用。

def predict_dict(params, inputs):
    return predict(params['W'], params['b'], inputs)

J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
    print("Jacobian from {} to logits is".format(k))
    print(v)
Jacobian from W to logits is
[[ 0.05981757  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]
Jacobian from b to logits is
[0.11503381 0.04563541 0.23439017 0.00189771]

有关正向模式和反向模式的更多详细信息,以及如何尽可能高效地实现 jacfwdjacrev,请继续阅读!

使用这两个函数的组合可以计算密集的 Hessian 矩阵。

def hessian(f):
    return jacfwd(jacrev(f))

H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)
hessian, with shape (4, 3, 3)
[[[ 0.02285465  0.04922541  0.03384247]
  [ 0.04922541  0.10602397  0.07289147]
  [ 0.03384247  0.07289147  0.05011288]]

 [[-0.03195215  0.03921401 -0.00544639]
  [ 0.03921401 -0.04812629  0.00668421]
  [-0.00544639  0.00668421 -0.00092836]]

 [[-0.01583708 -0.00182736  0.03959271]
  [-0.00182736 -0.00021085  0.00456839]
  [ 0.03959271  0.00456839 -0.09898177]]

 [[-0.00103524  0.00348343 -0.00194457]
  [ 0.00348343 -0.01172127  0.0065432 ]
  [-0.00194457  0.0065432  -0.00365263]]]

此形状有意义:如果我们从一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\) 开始,那么在点 \(x \in \mathbb{R}^n\) 处,我们预计会得到以下形状

  • \(f(x) \in \mathbb{R}^m\)\(f\)\(x\) 处的函数值

  • \(\partial f(x) \in \mathbb{R}^{m \times n}\)\(x\) 处的雅可比行列式矩阵

  • \(\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}\)\(x\) 处的 Hessian

等等。

为了实现 hessian,我们可以使用 jacfwd(jacrev(f))jacrev(jacfwd(f)) 或这两个函数的任何其他组合。但正向模式覆盖反向模式通常是最有效的。这是因为在内部雅可比行列式计算中,我们通常会微分一个具有宽雅可比行列式矩阵的函数(例如,损失函数 \(f : \mathbb{R}^n \to \mathbb{R}\)),而在外部雅可比行列式计算中,我们则会微分一个具有方阵雅可比行列式矩阵的函数(因为 \(\nabla f : \mathbb{R}^n \to \mathbb{R}^n\)),这正是正向模式胜出的地方。

实现原理:两个基础自动微分函数#

雅可比行列式-向量积(JVP,也称为正向模式自动微分)#

JAX 包含正向模式和反向模式自动微分的有效且通用的实现。熟悉的 grad 函数构建在反向模式之上,但为了解释两种模式的差异以及何时可以使用每种模式,我们需要一些数学背景。

数学中的 JVP#

在数学上,给定一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\)\(f\) 在输入点 \(x \in \mathbb{R}^n\) 处计算得出的雅可比行列式,记为 \(\partial f(x)\),通常被认为是 \(\mathbb{R}^m \times \mathbb{R}^n\) 中的矩阵。

\(\qquad \partial f(x) \in \mathbb{R}^{m \times n}\).

但我们也可以将 \(\partial f(x)\) 视为一个线性映射,该映射将 \(f\) 域在点 \(x\) 处的切空间(它只是 \(\mathbb{R}^n\) 的另一个副本)映射到 \(f\) 陪域在点 \(f(x)\) 处的切空间(\(\mathbb{R}^m\) 的副本)。

\(\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\).

此映射称为 \(f\)\(x\) 处的 前推映射。雅可比行列式矩阵只是此线性映射在标准基中的矩阵。

如果我们不确定某个特定的输入点 \(x\),那么我们可以将函数 \(\partial f\) 视为首先获取一个输入点,然后返回该输入点处的雅可比行列式线性映射。

\(\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m\).

特别是,我们可以取消对函数进行柯里化,以便给定输入点 \(x \in \mathbb{R}^n\) 和一个切向量 \(v \in \mathbb{R}^n\),我们可以获得 \(\mathbb{R}^m\) 中的输出切向量。我们将此映射称为从 \((x, v)\) 对到输出切向量的映射,并将其记为 雅可比行列式-向量积,写作

\(\qquad (x, v) \mapsto \partial f(x) v\)

JAX 代码中的 JVP#

回到 Python 代码中,JAX 的 jvp 函数模拟了这种转换。给定一个用于评估 \(f\) 的 Python 函数,JAX 的 jvp 是一种获取用于评估 \((x, v) \mapsto (f(x), \partial f(x) v)\) 的 Python 函数的方法。

from jax import jvp

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

key, subkey = random.split(key)
v = random.normal(subkey, W.shape)

# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, (W,), (v,))

类似 Haskell 的类型签名 而言,我们可以写成

jvp :: (a -> b) -> a -> T a -> (b, T b)

其中我们使用 T a 来表示 a 的切线空间类型。换句话说,jvp 以类型为 a -> b 的函数、类型为 a 的值和类型为 T a 的切线向量值作为参数。它返回一对,包含一个类型为 b 的值和一个类型为 T b 的输出切线向量。

经过 jvp 转换的函数的评估方式与原始函数非常相似,但它会将类型为 T a 的切线值与类型为 a 的每个原始值配对。对于原始函数将应用的每个基本数值运算,经过 jvp 转换的函数将执行该基本运算的“JVP 规则”,该规则既会在原始值上评估基本运算,又在这些原始值上应用基本运算的 JVP。

这种评估策略对计算复杂度有一些直接的影响:由于我们在评估过程中评估 JVP,因此我们无需存储任何内容以供日后使用,因此内存成本与计算的深度无关。此外,经过 jvp 转换的函数的 FLOP 成本大约是仅评估函数成本的 3 倍(评估原始函数的工作量为 1 个单位,例如 sin(x);线性化的工作量为 1 个单位,例如 cos(x);将线性化函数应用于向量的单位工作量为 1 个单位,例如 cos_x * v)。换句话说,对于固定的原始点 \(x\),我们可以评估 \(v \mapsto \partial f(x) \cdot v\),其边际成本与评估 \(f\) 的边际成本大致相同。

这种内存复杂度听起来相当有说服力!那么为什么我们很少在机器学习中看到前向模式呢?

为了回答这个问题,首先考虑一下如何使用 JVP 构建完整的雅可比矩阵。如果我们将 JVP 应用于独热切线向量,它将揭示雅可比矩阵的一列,对应于我们输入的非零项。因此,我们可以一次构建一个完整的雅可比矩阵,并且获取每列的成本与一次函数评估的成本大致相同。对于具有“高”雅可比矩阵的函数来说,这将很有效率,但对于具有“宽”雅可比矩阵的函数来说,这将效率低下。

如果您在机器学习中进行基于梯度的优化,您可能希望将从 \(\mathbb{R}^n\) 中的参数到 \(\mathbb{R}\) 中的标量损失值的损失函数最小化。这意味着此函数的雅可比矩阵是一个非常宽的矩阵:\(\partial f(x) \in \mathbb{R}^{1 \times n}\),我们通常将其识别为梯度向量 \(\nabla f(x) \in \mathbb{R}^n\)。一次构建一个矩阵,每次调用所需的 FLOP 量与评估原始函数所需的 FLOP 量相似,这似乎效率低下!特别是对于训练神经网络而言,其中 \(f\) 是训练损失函数,\(n\) 可以达到数百万甚至数十亿,这种方法根本无法扩展。

为了更好地处理此类函数,我们只需要使用反向模式。

向量-雅可比积(VJP,也称为反向模式自动微分)#

前向模式为我们返回一个用于评估雅可比-向量积的函数,然后我们可以使用它一次构建一个雅可比矩阵;反向模式是一种获取用于评估向量-雅可比积(等效于雅可比转置-向量积)的函数的方法,我们可以使用它一次构建一个雅可比矩阵。

数学中的 VJP#

我们再次考虑一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\)。从我们对 JVP 的表示法开始,VJP 的表示法非常简单

\(\qquad (x, v) \mapsto v \partial f(x)\),

其中 \(v\)\(f\)\(x\) 处的余切空间的元素(与 \(\mathbb{R}^m\) 的另一个副本同构)。在严谨的情况下,我们应该将 \(v\) 视为一个线性映射 \(v : \mathbb{R}^m \to \mathbb{R}\),并且当我们写 \(v \partial f(x)\) 时,我们指的是函数组合 \(v \circ \partial f(x)\),其中类型匹配,因为 \(\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\)。但在常见情况下,我们可以将 \(v\) 识别为 \(\mathbb{R}^m\) 中的向量,并几乎可以互换地使用这两个向量,就像我们有时可能会在“列向量”和“行向量”之间来回切换一样,而没有太多注释。

通过这种识别,我们可以将 VJP 的线性部分视为 JVP 的线性部分的转置(或伴随共轭)

\(\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v\).

对于给定的点 \(x\),我们可以将签名写成

\(\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n\).

余切空间上的对应映射通常称为 \(f\)\(x\) 处的 拉回。对我们来说,关键是它从类似于 \(f\) 输出的东西到类似于 \(f\) 输入的东西,就像我们可能从转置的线性函数中预期的那样。

JAX 代码中的 VJP#

从数学切换回 Python,JAX 函数 vjp 可以获取用于评估 \(f\) 的 Python 函数,并为我们返回用于评估 VJP \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) 的 Python 函数。

from jax import vjp

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

y, vjp_fun = vjp(f, W)

key, subkey = random.split(key)
u = random.normal(subkey, y.shape)

# Pull back the covector `u` along `f` evaluated at `W`
v = vjp_fun(u)

类似 Haskell 的类型签名 而言,我们可以写成

vjp :: (a -> b) -> a -> (b, CT b -> CT a)

其中我们使用 CT a 来表示 a 的余切空间类型。换句话说,vjp 以类型为 a -> b 的函数和类型为 a 的点作为参数,并返回一对,包含一个类型为 b 的值和一个类型为 CT b -> CT a 的线性映射。

这很棒,因为它允许我们一次构建一个雅可比矩阵,并且评估 \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) 的 FLOP 成本大约是评估 \(f\) 的成本的三倍。特别地,如果我们想要函数 \(f : \mathbb{R}^n \to \mathbb{R}\) 的梯度,我们可以通过一次调用来实现。这就是为什么 grad 对于基于梯度的优化非常有效,即使对于数百万或数十亿个参数上的神经网络训练损失函数等目标也是如此。

不过,也有一些成本:虽然 FLOP 很友好,但内存会随着计算的深度而扩展。此外,传统的实现比前向模式的实现更复杂,尽管 JAX 有一些技巧(这将是以后笔记本的内容!)。

有关反向模式工作原理的更多信息,请参阅 2017 年深度学习夏季学校的本教程视频

使用 VJP 的向量值梯度#

如果您有兴趣获取向量值梯度(例如 tf.gradients

from jax import vjp

def vgrad(f, x):
  y, vjp_fn = vjp(f, x)
  return vjp_fn(jnp.ones(y.shape))[0]

print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2))))
[[6. 6.]
 [6. 6.]]

使用前向模式和反向模式的 Hessian-向量积#

在上一节中,我们仅使用反向模式实现了 Hessian-向量积函数(假设连续二阶导数)

def hvp(f, x, v):
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

这很有效率,但我们可以做得更好,并通过将前向模式与反向模式一起使用来节省一些内存。

在数学上,给定一个要微分的函数 \(f : \mathbb{R}^n \to \mathbb{R}\)、一个要线性化函数的点 \(x \in \mathbb{R}^n\) 和一个向量 \(v \in \mathbb{R}^n\),我们想要的 Hessian-向量积函数是

\((x, v) \mapsto \partial^2 f(x) v\)

考虑辅助函数 \(g : \mathbb{R}^n \to \mathbb{R}^n\),它被定义为 \(f\) 的导数(或梯度),即 \(g(x) = \partial f(x)\)。我们只需要它的 JVP,因为它将为我们提供

\((x, v) \mapsto \partial g(x) v = \partial^2 f(x) v\).

我们可以将它几乎直接转换为代码

from jax import jvp, grad

# forward-over-reverse
def hvp(f, primals, tangents):
  return jvp(grad(f), primals, tangents)[1]

更好的是,由于我们没有直接调用 jnp.dot,因此此 hvp 函数适用于任何形状的数组,并适用于任意容器类型(例如存储为嵌套列表/字典/元组的向量),甚至不依赖于 jax.numpy

这是一个如何使用它的示例

def f(X):
  return jnp.sum(jnp.tanh(X)**2)

key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))

ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, 2)

print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))
True

您可能考虑的另一种编写方式是使用反向-前向

# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
  g = lambda primals: jvp(f, primals, tangents)[1]
  return grad(g)(primals)

但这不是最好的,因为前向模式的开销比反向模式低,并且由于这里的外层微分运算符必须对比内层运算符更复杂的计算进行微分,因此在外层保持前向模式效果最好

# reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
  x, = primals
  v, = tangents
  return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)


print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))

print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)
Forward over reverse
9.13 ms ± 101 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
19.3 ms ± 10.9 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
26.4 ms ± 15.5 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
76 ms ± 582 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)

组合 VJP、JVP 和 vmap#

雅可比矩阵和矩阵-雅可比积#

现在我们有了 jvpvjp 变换,它们可以让我们得到函数来一次向前或向后推单个向量,我们可以使用 JAX 的 vmap 变换 来一次性向前或向后推整个基底。特别是,我们可以用它来编写快速的矩阵-雅可比和雅可比-矩阵乘积。

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    return jnp.vstack([vjp_fun(mi) for mi in M])

# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    outs, = vmap(vjp_fun)(M)
    return outs

key = random.key(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)

loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)

print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
Non-vmapped Matrix-Jacobian product
397 ms ± 1.61 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Matrix-Jacobian product
10.6 ms ± 188 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
/tmp/ipykernel_1490/3769736790.py:8: DeprecationWarning: vstack requires ndarray or scalar arguments, got <class 'tuple'> at position 0. In a future JAX release this will be an error.
  return jnp.vstack([vjp_fun(mi) for mi in M])
def loop_jmp(f, W, M):
    # jvp immediately returns the primal and tangent values as a tuple,
    # so we'll compute and select the tangents in a list comprehension
    return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])

def vmap_jmp(f, W, M):
    _jvp = lambda s: jvp(f, (W,), (s,))[1]
    return vmap(_jvp)(M)

num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)

loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
Non-vmapped Jacobian-Matrix product
421 ms ± 393 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Jacobian-Matrix product
5.34 ms ± 162 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)

jacfwdjacrev 的实现 #

现在我们已经看到了快速的雅可比-矩阵和矩阵-雅可比乘积,写出 jacfwdjacrev 就并不难。我们只需要使用同样的技巧来一次向前或向后推整个标准基底(与单位矩阵同构)。

from jax import jacrev as builtin_jacrev

def our_jacrev(f):
    def jacfun(x):
        y, vjp_fun = vjp(f, x)
        # Use vmap to do a matrix-Jacobian product.
        # Here, the matrix is the Euclidean basis, so we get all
        # entries in the Jacobian at once.
        J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
        return J
    return jacfun

assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'
from jax import jacfwd as builtin_jacfwd

def our_jacfwd(f):
    def jacfun(x):
        _jvp = lambda s: jvp(f, (x,), (s,))[1]
        Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
        return jnp.transpose(Jt)
    return jacfun

assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'

有趣的是,Autograd 无法做到这一点。我们在 Autograd 中反向模式 jacobian实现 必须使用外部循环 map 来一次向后拉一个向量。一次性将一个向量推过计算比将它们全部一起批处理使用 vmap 的效率要低得多。

Autograd 无法做到的另一件事是 jit。有趣的是,无论你在要微分的函数中使用了多少 Python 动态性,我们总是在计算的线性部分使用 jit。例如

def f(x):
    try:
        if x < 3:
            return 2 * x ** 3
        else:
            raise ValueError
    except ValueError:
        return jnp.pi * x

y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.))
(Array(3.1415927, dtype=float32, weak_type=True),)

复数和微分 #

JAX 非常擅长复数和微分。为了同时支持 全纯微分和非全纯微分,用 JVP 和 VJP 来思考会有所帮助。

考虑一个复数到复数的函数 \(f: \mathbb{C} \to \mathbb{C}\),并将它与一个对应的函数 \(g: \mathbb{R}^2 \to \mathbb{R}^2\) 标识。

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return u(x, y) + v(x, y) * 1j

def g(x, y):
  return (u(x, y), v(x, y))

也就是说,我们已经将 \(f(z) = u(x, y) + v(x, y) i\) 分解,其中 \(z = x + y i\),并将 \(\mathbb{C}\)\(\mathbb{R}^2\) 标识得到 \(g\)

由于 \(g\) 仅涉及实数输入和输出,我们已经知道如何为它编写雅可比-向量乘积,例如给定一个切向量 \((c, d) \in \mathbb{R}^2\),即

\(\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).

为了得到原始函数 \(f\) 对一个切向量 \(c + di \in \mathbb{C}\) 应用的 JVP,我们只需使用相同的定义并将结果识别为另一个复数,

\(\partial f(x + y i)(c + d i) = \begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).

这就是我们对 \(\mathbb{C} \to \mathbb{C}\) 函数 JVP 的定义!注意,\(f\) 是否为全纯并不重要:JVP 是明确的。

这是一个检查

def check(seed):
  key = random.key(seed)

  # random coeffs for u and v
  key, subkey = random.split(key)
  a, b, c, d = random.uniform(subkey, (4,))

  def fun(z):
    x, y = jnp.real(z), jnp.imag(z)
    return u(x, y) + v(x, y) * 1j

  def u(x, y):
    return a * x + b * y

  def v(x, y):
    return c * x + d * y

  # primal point
  key, subkey = random.split(key)
  x, y = random.uniform(subkey, (2,))
  z = x + y * 1j

  # tangent vector
  key, subkey = random.split(key)
  c, d = random.uniform(subkey, (2,))
  z_dot = c + d * 1j

  # check jvp
  _, ans = jvp(fun, (z,), (z_dot,))
  expected = (grad(u, 0)(x, y) * c +
              grad(u, 1)(x, y) * d +
              grad(v, 0)(x, y) * c * 1j+
              grad(v, 1)(x, y) * d * 1j)
  print(jnp.allclose(ans, expected))
check(0)
check(1)
check(2)
True
True
True

那么 VJP 呢?我们做的事情非常类似:对于一个余切向量 \(c + di \in \mathbb{C}\),我们将 \(f\) 的 VJP 定义为

\((c + di)^* \; \partial f(x + y i) = \begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} 1 \\ -i \end{bmatrix}\).

为什么有负号?它们只是为了处理复共轭,以及我们正在处理余向量的事实。

这是一个 VJP 规则的检查

def check(seed):
  key = random.key(seed)

  # random coeffs for u and v
  key, subkey = random.split(key)
  a, b, c, d = random.uniform(subkey, (4,))

  def fun(z):
    x, y = jnp.real(z), jnp.imag(z)
    return u(x, y) + v(x, y) * 1j

  def u(x, y):
    return a * x + b * y

  def v(x, y):
    return c * x + d * y

  # primal point
  key, subkey = random.split(key)
  x, y = random.uniform(subkey, (2,))
  z = x + y * 1j

  # cotangent vector
  key, subkey = random.split(key)
  c, d = random.uniform(subkey, (2,))
  z_bar = jnp.array(c + d * 1j)  # for dtype control

  # check vjp
  _, fun_vjp = vjp(fun, z)
  ans, = fun_vjp(z_bar)
  expected = (grad(u, 0)(x, y) * c +
              grad(v, 0)(x, y) * (-d) +
              grad(u, 1)(x, y) * c * (-1j) +
              grad(v, 1)(x, y) * (-d) * (-1j))
  assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)
check(0)
check(1)
check(2)

那么像 gradjacfwdjacrev 这样的便捷包装器呢?

对于 \(\mathbb{R} \to \mathbb{R}\) 函数,回想一下,我们将 grad(f)(x) 定义为 vjp(f, x)[1](1.0),这是有效的,因为将 VJP 应用于一个 1.0 值会揭示梯度(即雅可比矩阵或导数)。我们可以对 \(\mathbb{C} \to \mathbb{R}\) 函数做同样的事情:我们仍然可以使用 1.0 作为余切向量,我们只是得到了一个总结完整雅可比矩阵的复数结果。

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return x**2 + y**2

z = 3. + 4j
grad(f)(z)
Array(6.-8.j, dtype=complex64)

对于一般的 \(\mathbb{C} \to \mathbb{C}\) 函数,雅可比矩阵具有 4 个实数值自由度(如上面的 2x2 雅可比矩阵),因此我们不能希望将它们全部表示在一个复数中。但是,对于全纯函数,我们可以做到!全纯函数恰好是一个 \(\mathbb{C} \to \mathbb{C}\) 函数,具有以下特殊性质:它的导数可以用单个复数表示。(柯西-黎曼方程 确保上面的 2x2 雅可比矩阵具有复平面中缩放和旋转矩阵的特殊形式,即复数在乘法下的作用。)我们可以通过使用 1.0 作为余向量的单个 vjp 调用来揭示这个复数。

因为这仅对全纯函数有效,所以为了使用这个技巧,我们需要向 JAX 保证我们的函数是全纯的;否则,当使用 grad 来处理复数输出函数时,JAX 会抛出错误。

def f(z):
  return jnp.sin(z)

z = 3. + 4j
grad(f, holomorphic=True)(z)
Array(-27.034946-3.8511534j, dtype=complex64, weak_type=True)

holomorphic=True 所做的只是在输出是复数值时禁用错误。当函数不是全纯时,我们仍然可以写出 holomorphic=True,但我们得到的答案不会代表完整的雅可比矩阵。相反,它将是函数的雅可比矩阵,我们只是丢弃了输出的虚部。

def f(z):
  return jnp.conjugate(z)

z = 3. + 4j
grad(f, holomorphic=True)(z)  # f is not actually holomorphic!
Array(1.-0.j, dtype=complex64, weak_type=True)

这里有一些关于 grad 如何工作时的有用结果

  1. 我们可以在全纯 \(\mathbb{C} \to \mathbb{C}\) 函数上使用 grad

  2. 我们可以使用 grad 来优化 \(f : \mathbb{C} \to \mathbb{R}\) 函数,例如复数参数 x 的实值损失函数,通过沿着 grad(f)(x) 共轭方向进行迭代。

  3. 如果我们有一个 \(\mathbb{R} \to \mathbb{R}\) 函数,它恰好内部使用了一些复数值运算(其中一些必须是非全纯的,例如卷积中使用的 FFT),那么 grad 仍然有效,我们得到的结果与仅使用实数值的实现相同。

无论如何,JVP 和 VJP 始终是明确的。如果我们想计算非全纯 \(\mathbb{C} \to \mathbb{C}\) 函数的完整雅可比矩阵,我们可以使用 JVP 或 VJP 来做到!

你应该期望复数在 JAX 中无处不在。这里有一个对复数矩阵的 Cholesky 分解进行微分的例子

A = jnp.array([[5.,    2.+3j,    5j],
              [2.-3j,   7.,  1.+7j],
              [-5j,  1.-7j,    12.]])

def f(X):
    L = jnp.linalg.cholesky(X)
    return jnp.sum((L - jnp.sin(L))**2)

grad(f, holomorphic=True)(A)
Array([[-0.7534186  +0.j       , -3.0509028 -10.940544j ,
         5.9896846  +3.5423026j],
       [-3.0509028 +10.940544j , -8.904491   +0.j       ,
        -5.1351523  -6.559373j ],
       [ 5.9896846  -3.5423026j, -5.1351523  +6.559373j ,
         0.01320427 +0.j       ]], dtype=complex64)

更高级的自动微分 #

在本笔记本中,我们通过了 JAX 中自动微分的一些简单,然后逐渐变得更加复杂的应用。我们希望你现在觉得在 JAX 中求导是容易且强大的。

还有很多其他的自动微分技巧和功能。我们没有涵盖的主题,但希望在“高级自动微分食谱”中涵盖的主题包括

  • 高斯-牛顿向量乘积,线性化一次

  • 自定义 VJP 和 JVP

  • 固定点处的有效导数

  • 使用随机 Hessian-向量乘积估计 Hessian 的迹。

  • 仅使用反向模式自动微分进行正向模式自动微分。

  • 对自定义数据类型进行微分。

  • 检查点(二项式检查点以实现高效的反向模式,而不是模型快照)。

  • 使用雅可比预累积优化 VJP。