jax.experimental.jet
模块#
Jet 是一个用于高阶自动微分的实验性模块,它不依赖于重复的一阶自动微分。
原理是什么?通过传播截断的泰勒多项式。考虑一个函数 \(f = g \circ h\),某个点 \(x\) 和某个偏移量 \(v\)。一阶自动微分(例如 jax.jvp()
)从对 \((h(x), \partial h(x)[v])\) 计算对 \((f(x), \partial f(x)[v])\)。
jet()
实现了高阶类似物:给定元组
它表示 \(h\) 在 \(x\) 处的 \(K\) 阶泰勒近似,jet()
返回 \(f\) 在 \(x\) 处的 \(K\) 阶泰勒近似,
更具体地说,jet()
计算
因此可以用于 \(f\) 的高阶自动微分。详细信息请参考这些注释。
API#
- jax.experimental.jet.jet(fun, primals, series)[源代码]#
泰勒模式高阶自动微分。
- 参数:
fun – 要微分的函数。其参数应为数组、标量或数组或标量的标准 Python 容器。它应返回数组、标量或数组或标量的标准 Python 容器。
primals – 应该评估
fun
的泰勒近似的原始值。它应该是一个元组或一个参数列表,其长度应等于fun
的位置参数的数量。series – 高阶泰勒级数系数。 primals 和 series 一起构成截断的泰勒多项式。它应该是一个元组或一个元组或列表的列表,其长度决定了截断的泰勒多项式的阶数。
- 返回:
一个
(primals_out, series_out)
对,其中primals_out
是fun(*primals)
,并且primals_out
和series_out
一起是 \(f(h(\cdot))\) 的截断泰勒多项式。primals_out
值具有与primals
相同的 Python 树结构,而series_out
值具有与series
相同的 Python 树结构。
例如
>>> import jax >>> import jax.numpy as np
考虑函数 \(h(z) = z^3\),\(x = 0.5\),以及前几个泰勒系数 \(h_0=x^3\),\(h_1=3x^2\) 和 \(h_2=6x\)。令 \(f(y) = \sin(y)\)。
>>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5 >>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)
jet()
根据 Faà di Bruno 公式返回 \(f(h(z)) = \sin(z^3)\) 的泰勒系数>>> f0, (f1, f2) = jet(f, (h0,), ((h1, h2),)) >>> print(f0, f(h0)) 0.12467473 0.12467473
>>> print(f1, df(h0) * h1) 0.7441479 0.74414825
>>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2) 2.9064622 2.9064634