自动微分手册#
JAX 有一个相当通用的自动微分系统。在本笔记本中,我们将介绍一系列你可以为自己的工作选择的简洁的自动微分概念,从基础开始。
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
key = random.key(0)
梯度#
从 grad
开始#
你可以使用 grad
对函数求微分
grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
0.070650816
grad
接受一个函数并返回一个函数。如果你有一个 Python 函数 f
,它评估数学函数 \(f\),那么 grad(f)
是一个 Python 函数,它评估数学函数 \(\nabla f\)。这意味着 grad(f)(x)
表示值 \(\nabla f(x)\)。
由于 grad
对函数进行操作,你可以将其应用于其自身的输出来进行多次微分。
print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))
-0.13621868
0.25265405
让我们看看在线性逻辑回归模型中使用 grad
计算梯度。首先,进行设置。
def sigmoid(x):
return 0.5 * (jnp.tanh(x / 2) + 1)
# Outputs probability of a label being true.
def predict(W, b, inputs):
return sigmoid(jnp.dot(inputs, W) + b)
# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12, 0.77],
[0.88, -1.08, 0.15],
[0.52, 0.06, -1.30],
[0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])
# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
preds = predict(W, b, inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())
使用 grad
函数及其 argnums
参数来微分函数相对于位置参数。
# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)
# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)
# But we can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)
# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)
W_grad [-0.16965583 -0.8774644 -1.4901346 ]
W_grad [-0.16965583 -0.8774644 -1.4901346 ]
b_grad -0.29227245
W_grad [-0.16965583 -0.8774644 -1.4901346 ]
b_grad -0.29227245
这个 grad
API 与斯皮瓦克经典著作《流形上的微积分》(1965) 中出色的符号表示直接对应,该符号也用于萨斯曼和威斯顿的《经典力学的结构和解释》(2015) 及其《泛函微分几何》(2013)。这两本书都是开放获取的。请特别参阅《泛函微分几何》的“序言”部分,其中对这种符号表示进行了辩护。
本质上,当使用 argnums
参数时,如果 f
是用于评估数学函数 \(f\) 的 Python 函数,则 Python 表达式 grad(f, i)
会求值为一个用于评估 \(\partial_i f\) 的 Python 函数。
对嵌套列表、元组和字典求微分#
对标准 Python 容器求微分是可以直接进行的,因此可以随意使用元组、列表和字典(以及任意嵌套)。
def loss2(params_dict):
preds = predict(params_dict['W'], params_dict['b'], inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
print(grad(loss2)({'W': W, 'b': b}))
{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)}
你可以注册你自己的容器类型,以便不仅与 grad
一起使用,而且还可以与所有 JAX 转换(jit
、vmap
等)一起使用。
使用 value_and_grad
评估函数及其梯度#
另一个方便的函数是 value_and_grad
,它用于高效地计算函数的值及其梯度值。
from jax import value_and_grad
loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 3.0519385
loss value 3.0519385
与数值差异进行比较#
关于导数的一个优点是,它们可以通过有限差分进行直接检查。
# Set a step size for finite differences calculations
eps = 1e-4
# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))
# Check W_grad with finite differences in a random direction
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245
W_dirderiv_numerical -0.2002716
W_dirderiv_autodiff -0.19909117
JAX 提供了一个简单的便利函数,它执行的操作基本相同,但会检查你喜欢的任何微分阶数。
from jax.test_util import check_grads
check_grads(loss, (W, b), order=2) # check up to 2nd order derivatives
使用 grad
的 grad
求 Hessian-向量积#
我们可以使用高阶 grad
做的一件事是构建一个 Hessian-向量积函数。(稍后我们将编写一个更有效的实现,该实现结合了正向模式和反向模式,但这个将使用纯反向模式。)
Hessian-向量积函数在截断牛顿共轭梯度算法中可以用于最小化光滑凸函数,或者用于研究神经网络训练目标的曲率(例如,1、2、3、4)。
对于具有连续二阶导数的标量值函数 \(f : \mathbb{R}^n \to \mathbb{R}\)(因此 Hessian 矩阵是对称的),在点 \(x \in \mathbb{R}^n\) 处的 Hessian 可以写为 \(\partial^2 f(x)\)。然后,Hessian-向量积函数可以评估
\(\qquad v \mapsto \partial^2 f(x) \cdot v\)
对于任何 \(v \in \mathbb{R}^n\)。
诀窍是不实例化完整的 Hessian 矩阵:如果 \(n\) 很大,可能在神经网络的上下文中达到数百万或数十亿,那么可能无法存储它。
幸运的是,grad
已经为我们提供了一种编写高效 Hessian-向量积函数的方法。我们只需要使用恒等式
\(\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)\),
其中 \(g(x) = \partial f(x) \cdot v\) 是一个新的标量值函数,它将 \(f\) 在 \(x\) 处的梯度与向量 \(v\) 进行点积运算。请注意,我们只对向量值参数的标量值函数进行微分,这正是我们知道 grad
有效的地方。
在 JAX 代码中,我们可以这样编写
def hvp(f, x, v):
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
这个例子表明你可以自由使用词法闭包,而 JAX 永远不会感到困扰或困惑。
在接下来的几个单元格中,我们将检查此实现,一旦我们看到如何计算稠密 Hessian 矩阵。我们还将编写一个更好的版本,该版本同时使用正向模式和反向模式。
使用 jacfwd
和 jacrev
求雅可比矩阵和 Hessian 矩阵#
你可以使用 jacfwd
和 jacrev
函数计算完整的雅可比矩阵。
from jax import jacfwd, jacrev
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)
J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)
jacfwd result, with shape (4, 3)
[[ 0.05981758 0.12883787 0.08857603]
[ 0.04015916 -0.04928625 0.00684531]
[ 0.12188288 0.01406341 -0.3047072 ]
[ 0.00140431 -0.00472531 0.00263782]]
jacrev result, with shape (4, 3)
[[ 0.05981757 0.12883787 0.08857603]
[ 0.04015916 -0.04928625 0.00684531]
[ 0.12188289 0.01406341 -0.3047072 ]
[ 0.00140431 -0.00472531 0.00263782]]
这两个函数计算相同的值(达到机器数值),但在实现方式上有所不同:jacfwd
使用正向模式自动微分,对于“高”雅可比矩阵(输出多于输入)来说效率更高,而 jacrev
使用反向模式,对于“宽”雅可比矩阵(输入多于输出)来说效率更高。对于接近正方形的矩阵,jacfwd
可能比 jacrev
更有优势。
你还可以将 jacfwd
和 jacrev
与容器类型一起使用
def predict_dict(params, inputs):
return predict(params['W'], params['b'], inputs)
J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
print("Jacobian from {} to logits is".format(k))
print(v)
Jacobian from W to logits is
[[ 0.05981757 0.12883787 0.08857603]
[ 0.04015916 -0.04928625 0.00684531]
[ 0.12188289 0.01406341 -0.3047072 ]
[ 0.00140431 -0.00472531 0.00263782]]
Jacobian from b to logits is
[0.11503381 0.04563541 0.23439017 0.00189771]
有关正向模式和反向模式的更多详细信息,以及如何尽可能高效地实现 jacfwd
和 jacrev
,请继续阅读!
组合使用这两个函数中的两个函数可以为我们提供一种计算稠密 Hessian 矩阵的方法。
def hessian(f):
return jacfwd(jacrev(f))
H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)
hessian, with shape (4, 3, 3)
[[[ 0.02285465 0.04922541 0.03384247]
[ 0.04922541 0.10602397 0.07289147]
[ 0.03384247 0.07289147 0.05011288]]
[[-0.03195215 0.03921401 -0.00544639]
[ 0.03921401 -0.04812629 0.00668421]
[-0.00544639 0.00668421 -0.00092836]]
[[-0.01583708 -0.00182736 0.03959271]
[-0.00182736 -0.00021085 0.00456839]
[ 0.03959271 0.00456839 -0.09898177]]
[[-0.00103524 0.00348343 -0.00194457]
[ 0.00348343 -0.01172127 0.0065432 ]
[-0.00194457 0.0065432 -0.00365263]]]
这种形状是有意义的:如果我们从函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\) 开始,那么在点 \(x \in \mathbb{R}^n\) 处,我们希望得到以下形状
\(f(x) \in \mathbb{R}^m\),即 \(f\) 在 \(x\) 处的值,
\(\partial f(x) \in \mathbb{R}^{m \times n}\),即 \(x\) 处的雅可比矩阵,
\(\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}\),即 \(x\) 处的 Hessian 矩阵,
等等。
要实现 hessian
,我们可以使用 jacfwd(jacrev(f))
或 jacrev(jacfwd(f))
或这两个函数的任何其他组合。但正向模式优于反向模式通常效率最高。这是因为在内部雅可比矩阵计算中,我们通常会微分一个宽雅可比矩阵的函数(可能像一个损失函数 \(f : \mathbb{R}^n \to \mathbb{R}\)),而在外部雅可比矩阵计算中,我们会微分一个具有正方形雅可比矩阵的函数(因为 \(\nabla f : \mathbb{R}^n \to \mathbb{R}^n\)),这正是正向模式胜出的地方。
它的原理:两个基础的自动微分函数#
雅可比矩阵-向量积 (JVP,又名正向模式自动微分)#
JAX 包含了前向和反向模式自动微分的高效且通用的实现。我们熟悉的 grad
函数是基于反向模式构建的,但是为了解释这两种模式之间的区别,以及每种模式的用途,我们需要一些数学背景知识。
数学中的 JVP#
在数学上,给定一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\),\(f\) 在输入点 \(x \in \mathbb{R}^n\) 处的雅可比矩阵,记为 \(\partial f(x)\),通常被认为是 \(\mathbb{R}^m \times \mathbb{R}^n\) 中的一个矩阵。
\(\qquad \partial f(x) \in \mathbb{R}^{m \times n}\).
但我们也可以将 \(\partial f(x)\) 视为一个线性映射,它将 \(f\) 的域在点 \(x\) 处的切空间(只是 \(\mathbb{R}^n\) 的另一个副本)映射到 \(f\) 的值域在点 \(f(x)\) 处的切空间(\(\mathbb{R}^m\) 的一个副本)。
\(\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\).
此映射称为 \(f\) 在 \(x\) 处的前推映射。雅可比矩阵只是这个线性映射在标准基下的矩阵。
如果我们不确定一个特定的输入点 \(x\),那么我们可以将函数 \(\partial f\) 视为首先接收一个输入点,然后返回该输入点处的雅可比线性映射。
\(\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m\).
特别地,我们可以解柯里化,以便在给定输入点 \(x \in \mathbb{R}^n\) 和一个切向量 \(v \in \mathbb{R}^n\) 时,我们得到 \(\mathbb{R}^m\) 中的一个输出切向量。我们将从 \((x, v)\) 对到输出切向量的映射称为雅可比-向量积,并将其写为
\(\qquad (x, v) \mapsto \partial f(x) v\)
JAX 代码中的 JVP#
回到 Python 代码,JAX 的 jvp
函数模拟了这个转换。给定一个计算 \(f\) 的 Python 函数,JAX 的 jvp
是一种获取用于计算 \((x, v) \mapsto (f(x), \partial f(x) v)\) 的 Python 函数的方法。
from jax import jvp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
key, subkey = random.split(key)
v = random.normal(subkey, W.shape)
# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, (W,), (v,))
用类似 Haskell 的类型签名来说,我们可以写成
jvp :: (a -> b) -> a -> T a -> (b, T b)
其中我们使用 T a
来表示 a
的切空间的类型。换句话说,jvp
接收类型为 a -> b
的函数、类型为 a
的值以及类型为 T a
的切向量值作为参数。它返回一个包含类型为 b
的值和类型为 T b
的输出切向量的对。
jvp
转换后的函数与原始函数非常相似,但它会与类型为 a
的每个原始值配对,并沿着类型为 T a
的切线值进行推送。对于原始函数应用的每个原始数值运算,jvp
转换后的函数会针对该原始值执行一个“JVP 规则”,该规则既评估原始值上的原始值,又在该原始值上应用原始值的 JVP。
该评估策略对计算复杂度有一些直接影响:由于我们在进行评估时计算 JVP,因此我们无需存储任何内容以供以后使用,因此内存成本与计算深度无关。此外,jvp
转换后函数的 FLOP 成本约为仅评估函数成本的 3 倍(评估原始函数的一个单位工作量,例如 sin(x)
;一个用于线性化的单位,如 cos(x)
;以及一个用于将线性化函数应用于向量的单位,如 cos_x * v
)。换句话说,对于固定的原始点 \(x\),我们可以评估 \(v \mapsto \partial f(x) \cdot v\),其边际成本与评估 \(f\) 大致相同。
这种内存复杂性听起来很有吸引力!那么为什么我们在机器学习中不经常看到前向模式呢?
要回答这个问题,首先考虑一下如何使用 JVP 构建完整的雅可比矩阵。如果我们将 JVP 应用于单热切向量,它会显示雅可比矩阵的一列,对应于我们输入的非零项。因此,我们可以一次构建一列完整的雅可比矩阵,获取每一列的成本与一次函数评估的成本大致相同。这对于具有“高”雅可比矩阵的函数来说是高效的,但对于“宽”雅可比矩阵来说是低效的。
如果您在机器学习中进行基于梯度的优化,您可能希望将 \(\mathbb{R}^n\) 中的参数的损失函数最小化为 \(\mathbb{R}\) 中的标量损失值。这意味着此函数的雅可比矩阵是一个非常宽的矩阵:\(\partial f(x) \in \mathbb{R}^{1 \times n}\),我们通常将其与梯度向量 \(\nabla f(x) \in \mathbb{R}^n\) 标识。逐列构建该矩阵,每次调用都花费与评估原始函数相似的 FLOP 数量,这看起来确实效率低下!特别是,对于训练神经网络,其中 \(f\) 是训练损失函数,而 \(n\) 可能达到数百万或数十亿,这种方法根本无法扩展。
为了更好地处理此类函数,我们只需要使用反向模式。
向量-雅可比积(VJP,也称为反向模式自动微分)#
前向模式为我们提供了一个用于评估雅可比-向量积的函数,然后我们可以使用它来一次构建雅可比矩阵的一列,而反向模式是一种获取用于评估向量-雅可比积(等效于雅可比-转置-向量积)的函数的方法,我们可以使用它来一次构建雅可比矩阵的一行。
数学中的 VJP#
让我们再次考虑一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\)。从我们对 JVP 的表示法开始,VJP 的表示法非常简单。
\(\qquad (x, v) \mapsto v \partial f(x)\),
其中 \(v\) 是 \(f\) 在 \(x\) 处的余切空间(与 \(\mathbb{R}^m\) 的另一个副本同构)的元素。在严格的情况下,我们应该将 \(v\) 视为线性映射 \(v : \mathbb{R}^m \to \mathbb{R}\),当我们写 \(v \partial f(x)\) 时,我们的意思是函数组合 \(v \circ \partial f(x)\),其中类型是有效的,因为 \(\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\)。但在常见情况下,我们可以将 \(v\) 与 \(\mathbb{R}^m\) 中的向量进行标识,并且几乎可以互换地使用两者,就像我们有时可能会在“列向量”和“行向量”之间切换而无需太多注释一样。
通过这种识别,我们可以将 VJP 的线性部分替代地视为 JVP 线性部分的转置(或伴随共轭)
\(\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v\).
对于给定的点 \(x\),我们可以将签名写为
\(\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n\).
余切空间上的对应映射通常称为 \(f\) 在 \(x\) 处的拉回。对于我们的目的而言,关键在于它从看起来像 \(f\) 输出的内容转变为看起来像 \(f\) 输入的内容,正如我们对转置线性函数的期望一样。
JAX 代码中的 VJP#
从数学切换回 Python,JAX 函数 vjp
可以获取一个用于评估 \(f\) 的 Python 函数,并返回一个用于评估 VJP \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) 的 Python 函数。
from jax import vjp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
y, vjp_fun = vjp(f, W)
key, subkey = random.split(key)
u = random.normal(subkey, y.shape)
# Pull back the covector `u` along `f` evaluated at `W`
v = vjp_fun(u)
用类似 Haskell 的类型签名来说,我们可以写成
vjp :: (a -> b) -> a -> (b, CT b -> CT a)
其中我们使用 CT a
来表示 a
的余切空间的类型。换句话说,vjp
接收一个类型为 a -> b
的函数和一个类型为 a
的点作为参数,并返回一个包含类型为 b
的值和类型为 CT b -> CT a
的线性映射的对。
这非常棒,因为它允许我们一次构建雅可比矩阵的一行,并且评估 \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) 的 FLOP 成本仅约为评估 \(f\) 的成本的三倍。特别是,如果我们想要函数 \(f : \mathbb{R}^n \to \mathbb{R}\) 的梯度,我们只需调用一次即可完成。这就是为什么 grad
对于基于梯度的优化是高效的,即使对于数百万或数十亿参数的神经网络训练损失函数等目标也是如此。
不过,这也有一个代价:虽然 FLOP 很友好,但内存会随着计算深度的增加而增加。此外,该实现传统上比前向模式的实现更复杂,尽管 JAX 有一些技巧(这是一个未来笔记本的故事!)。
有关反向模式工作原理的更多信息,请参阅2017 年深度学习暑期学校的本教程视频。
使用 VJP 进行向量值梯度计算#
如果您有兴趣采用向量值梯度(如 tf.gradients
)
from jax import vjp
def vgrad(f, x):
y, vjp_fn = vjp(f, x)
return vjp_fn(jnp.ones(y.shape))[0]
print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2))))
[[6. 6.]
[6. 6.]]
使用前向和反向模式的 Hessian-向量积#
在前面的部分中,我们仅使用反向模式(假设二阶导数连续)实现了一个 Hessian-向量积函数
def hvp(f, x, v):
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
这很高效,但我们可以通过将前向模式与反向模式结合使用,从而做得更好并节省一些内存。
在数学上,给定一个要微分的函数 \(f : \mathbb{R}^n \to \mathbb{R}\),一个线性化函数的点 \(x \in \mathbb{R}^n\),以及一个向量 \(v \in \mathbb{R}^n\),我们想要的 Hessian-向量积函数是
\((x, v) \mapsto \partial^2 f(x) v\)
考虑辅助函数 \(g : \mathbb{R}^n \to \mathbb{R}^n\),定义为 \(f\) 的导数(或梯度),即 \(g(x) = \partial f(x)\)。我们只需要它的 JVP,因为它会给我们
\((x, v) \mapsto \partial g(x) v = \partial^2 f(x) v\).
我们可以将其几乎直接翻译成代码
from jax import jvp, grad
# forward-over-reverse
def hvp(f, primals, tangents):
return jvp(grad(f), primals, tangents)[1]
更棒的是,由于我们不必直接调用 jnp.dot
,所以这个 hvp
函数可以处理任何形状的数组和任意容器类型(如存储为嵌套列表/字典/元组的向量),甚至不依赖于 jax.numpy
。
这是一个如何使用它的例子
def f(X):
return jnp.sum(jnp.tanh(X)**2)
key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))
ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, 2)
print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))
True
您可能考虑编写此代码的另一种方法是使用反向覆盖前向
# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
g = lambda primals: jvp(f, primals, tangents)[1]
return grad(g)(primals)
但这不太好,因为前向模式的开销比反向模式小,并且由于此处外部微分算子必须微分比内部算子更大的计算,因此保持外部使用前向模式效果最佳
# reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
x, = primals
v, = tangents
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))
print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)
Forward over reverse
2.72 ms ± 42 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
The slowest run took 4.46 times longer than the fastest. This could mean that an intermediate result is being cached.
7.45 ms ± 5.63 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
The slowest run took 4.66 times longer than the fastest. This could mean that an intermediate result is being cached.
11.5 ms ± 8.82 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
37.6 ms ± 1.85 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
组合 VJP、JVP 和 vmap
#
雅可比矩阵和矩阵-雅可比积#
现在我们有了 jvp
和 vjp
变换,它们可以为我们提供一次向前推送或向后拉单个向量的函数,我们可以使用 JAX 的 vmap
变换 一次性推送和拉取整个基。特别是,我们可以使用它来编写快速的矩阵-雅可比积和雅可比-矩阵积。
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
y, vjp_fun = vjp(f, x)
return jnp.vstack([vjp_fun(mi) for mi in M])
# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
y, vjp_fun = vjp(f, x)
outs, = vmap(vjp_fun)(M)
return outs
key = random.key(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)
loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)
print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)
assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
Non-vmapped Matrix-Jacobian product
94.9 ms ± 622 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Vmapped Matrix-Jacobian product
3.22 ms ± 55.7 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
/tmp/ipykernel_1241/3769736790.py:8: DeprecationWarning: vstack requires ndarray or scalar arguments, got <class 'tuple'> at position 0. In a future JAX release this will be an error.
return jnp.vstack([vjp_fun(mi) for mi in M])
def loop_jmp(f, W, M):
# jvp immediately returns the primal and tangent values as a tuple,
# so we'll compute and select the tangents in a list comprehension
return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])
def vmap_jmp(f, W, M):
_jvp = lambda s: jvp(f, (W,), (s,))[1]
return vmap(_jvp)(M)
num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)
loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)
assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
Non-vmapped Jacobian-Matrix product
124 ms ± 329 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Vmapped Jacobian-Matrix product
1.49 ms ± 34.6 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
jacfwd
和 jacrev
的实现#
既然我们已经看到了快速的雅可比-矩阵积和矩阵-雅可比积,那么不难猜出如何编写 jacfwd
和 jacrev
。我们只是使用相同的技术一次性推送或拉取整个标准基(同构于单位矩阵)。
from jax import jacrev as builtin_jacrev
def our_jacrev(f):
def jacfun(x):
y, vjp_fun = vjp(f, x)
# Use vmap to do a matrix-Jacobian product.
# Here, the matrix is the Euclidean basis, so we get all
# entries in the Jacobian at once.
J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
return J
return jacfun
assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'
from jax import jacfwd as builtin_jacfwd
def our_jacfwd(f):
def jacfun(x):
_jvp = lambda s: jvp(f, (x,), (s,))[1]
Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
return jnp.transpose(Jt)
return jacfun
assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'
有趣的是,Autograd 做不到这一点。我们在 Autograd 中反向模式 jacobian
的实现 必须使用外循环 map
一次拉回一个向量。一次通过计算推送一个向量远不如使用 vmap
将它们全部批量处理有效。
Autograd 做不到的另一件事是 jit
。有趣的是,无论您在要微分的函数中使用多少 Python 动态特性,我们始终可以在计算的线性部分使用 jit
。例如
def f(x):
try:
if x < 3:
return 2 * x ** 3
else:
raise ValueError
except ValueError:
return jnp.pi * x
y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.))
(Array(3.1415927, dtype=float32, weak_type=True),)
复数和微分#
JAX 非常擅长处理复数和微分。为了支持全纯和非全纯微分,从 JVP 和 VJP 的角度进行思考会有所帮助。
考虑一个复数到复数的函数 \(f: \mathbb{C} \to \mathbb{C}\) 并将其与相应的函数 \(g: \mathbb{R}^2 \to \mathbb{R}^2\) 相对应,
def f(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def g(x, y):
return (u(x, y), v(x, y))
也就是说,我们已经将 \(f(z) = u(x, y) + v(x, y) i\) 分解,其中 \(z = x + y i\),并将 \(\mathbb{C}\) 与 \(\mathbb{R}^2\) 识别以获得 \(g\)。
由于 \(g\) 仅涉及实数输入和输出,我们已经知道如何为其编写雅可比-向量积,例如给定一个切向量 \((c, d) \in \mathbb{R}^2\),即
\(\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).
为了获得应用于切向量 \(c + di \in \mathbb{C}\) 的原始函数 \(f\) 的 JVP,我们只需使用相同的定义并将结果识别为另一个复数,
\(\partial f(x + y i)(c + d i) = \begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).
这就是我们对 \(\mathbb{C} \to \mathbb{C}\) 函数的 JVP 的定义!请注意,\(f\) 是否为全纯函数无关紧要:JVP 是明确的。
这是一个检查
def check(seed):
key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)
a, b, c, d = random.uniform(subkey, (4,))
def fun(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def u(x, y):
return a * x + b * y
def v(x, y):
return c * x + d * y
# primal point
key, subkey = random.split(key)
x, y = random.uniform(subkey, (2,))
z = x + y * 1j
# tangent vector
key, subkey = random.split(key)
c, d = random.uniform(subkey, (2,))
z_dot = c + d * 1j
# check jvp
_, ans = jvp(fun, (z,), (z_dot,))
expected = (grad(u, 0)(x, y) * c +
grad(u, 1)(x, y) * d +
grad(v, 0)(x, y) * c * 1j+
grad(v, 1)(x, y) * d * 1j)
print(jnp.allclose(ans, expected))
check(0)
check(1)
check(2)
True
True
True
那 VJP 呢?我们做的事情非常相似:对于余切向量 \(c + di \in \mathbb{C}\),我们将 \(f\) 的 VJP 定义为
\((c + di)^* \; \partial f(x + y i) = \begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} 1 \\ -i \end{bmatrix}\).
这些负号是怎么回事?它们只是为了处理复共轭,以及我们正在使用共向量的事实。
这是一个 VJP 规则的检查
def check(seed):
key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)
a, b, c, d = random.uniform(subkey, (4,))
def fun(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def u(x, y):
return a * x + b * y
def v(x, y):
return c * x + d * y
# primal point
key, subkey = random.split(key)
x, y = random.uniform(subkey, (2,))
z = x + y * 1j
# cotangent vector
key, subkey = random.split(key)
c, d = random.uniform(subkey, (2,))
z_bar = jnp.array(c + d * 1j) # for dtype control
# check vjp
_, fun_vjp = vjp(fun, z)
ans, = fun_vjp(z_bar)
expected = (grad(u, 0)(x, y) * c +
grad(v, 0)(x, y) * (-d) +
grad(u, 1)(x, y) * c * (-1j) +
grad(v, 1)(x, y) * (-d) * (-1j))
assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)
check(0)
check(1)
check(2)
那么诸如 grad
、jacfwd
和 jacrev
之类的便捷包装器呢?
对于 \(\mathbb{R} \to \mathbb{R}\) 函数,回想一下,我们将 grad(f)(x)
定义为 vjp(f, x)[1](1.0)
,之所以可行是因为将 VJP 应用于 1.0
值会揭示梯度(即雅可比,或导数)。我们可以对 \(\mathbb{C} \to \mathbb{R}\) 函数做同样的事情:我们仍然可以使用 1.0
作为余切向量,并且我们只会得到一个总结完整雅可比的复数结果
def f(z):
x, y = jnp.real(z), jnp.imag(z)
return x**2 + y**2
z = 3. + 4j
grad(f)(z)
Array(6.-8.j, dtype=complex64)
对于一般的 \(\mathbb{C} \to \mathbb{C}\) 函数,雅可比具有 4 个实数值自由度(如上面的 2x2 雅可比矩阵),因此我们无法希望在复数中表示所有这些自由度。但是我们可以为全纯函数做到这一点!全纯函数恰好是一个 \(\mathbb{C} \to \mathbb{C}\) 函数,其具有特殊属性,即其导数可以表示为单个复数。(柯西-黎曼方程 确保上面的 2x2 雅可比在复平面中具有缩放和旋转矩阵的特殊形式,即单个复数在乘法下的作用。)并且我们可以使用对 vjp
的单个调用并使用 1.0
的共向量来揭示该复数。
因为这仅适用于全纯函数,所以要使用此技巧,我们需要向 JAX 保证我们的函数是全纯函数;否则,当 grad
用于复数输出函数时,JAX 将引发错误
def f(z):
return jnp.sin(z)
z = 3. + 4j
grad(f, holomorphic=True)(z)
Array(-27.034946-3.8511534j, dtype=complex64, weak_type=True)
所有 holomorphic=True
承诺所做的是禁用输出为复数值时的错误。当函数不是全纯函数时,我们仍然可以编写 holomorphic=True
,但是我们得到的结果不会表示完整的雅可比。相反,它将是函数的雅可比,我们在其中仅丢弃输出的虚部
def f(z):
return jnp.conjugate(z)
z = 3. + 4j
grad(f, holomorphic=True)(z) # f is not actually holomorphic!
Array(1.-0.j, dtype=complex64, weak_type=True)
对于 grad
在这里的工作方式,有一些有用的结果
我们可以在全纯的 \(\mathbb{C} \to \mathbb{C}\) 函数上使用
grad
。我们可以使用
grad
来优化 \(f : \mathbb{C} \to \mathbb{R}\) 函数,例如复数参数x
的实值损失函数,方法是沿grad(f)(x)
的共轭方向迈进。如果我们有一个 \(\mathbb{R} \to \mathbb{R}\) 函数,它恰好在内部使用一些复数值运算(其中一些必须是非全纯的,例如卷积中使用的 FFT),则
grad
仍然有效,并且我们得到的结果与仅使用实数值的实现会给出的结果相同。
在任何情况下,JVP 和 VJP 始终是明确的。如果我们想计算非全纯 \(\mathbb{C} \to \mathbb{C}\) 函数的完整雅可比矩阵,我们可以使用 JVP 或 VJP 来完成!
您应该期望复数在 JAX 中的任何地方都可以使用。这是对复矩阵的 Cholesky 分解进行微分
A = jnp.array([[5., 2.+3j, 5j],
[2.-3j, 7., 1.+7j],
[-5j, 1.-7j, 12.]])
def f(X):
L = jnp.linalg.cholesky(X)
return jnp.sum((L - jnp.sin(L))**2)
grad(f, holomorphic=True)(A)
Array([[-0.7534186 +0.j , -3.0509028 -10.940544j ,
5.9896846 +3.5423026j],
[-3.0509028 +10.940544j , -8.904491 +0.j ,
-5.1351523 -6.559373j ],
[ 5.9896846 -3.5423026j, -5.1351523 +6.559373j ,
0.01320427 +0.j ]], dtype=complex64)
更高级的自动微分#
在本笔记本中,我们研究了 JAX 中自动微分的一些简单,然后逐渐变得更复杂的应用。我们希望您现在感觉在 JAX 中进行微分既简单又强大。
那里还有许多其他自动微分技巧和功能。我们未涵盖但在“高级自动微分食谱”中希望涵盖的主题包括
高斯-牛顿向量积,一次线性化
自定义 VJP 和 JVP
定点处的有效导数
使用随机 Hessian-向量积估计 Hessian 的迹。
仅使用反向模式自动微分实现前向模式自动微分。
针对自定义数据类型求导。
检查点(用于高效反向模式的二项式检查点,而非模型快照)。
使用雅可比矩阵预累积优化VJPs。