jax.linearize#

jax.linearize(fun: Callable, *primals, has_aux: Literal[False] = False) tuple[Any, Callable][source]#
jax.linearize(fun: Callable, *primals, has_aux: Literal[True]) tuple[Any, Callable, Any]

使用 jvp() 和部分求值,生成函数 fun 的线性近似。

参数:
  • fun – 要微分的函数。它的参数应该是数组、标量或数组或标量的标准 Python 容器。它应该返回一个数组、标量或数组或标量的标准 Python 容器。

  • primals – 应该在其中评估 fun 的雅可比矩阵的原始值。 应该是数组、标量或其标准 Python 容器的元组。该元组的长度等于 fun 的位置参数的数量。

  • has_aux – 可选,布尔值。 指示 fun 是否返回一个对,其中第一个元素被认为是线性化的数学函数的输出,第二个是辅助数据。 默认值为 False。

返回值:

如果 has_auxFalse,则返回一个对,其中第一个元素是 f(*primals) 的值,第二个元素是一个函数,用于评估在 primals 处评估的 fun 的(前向模式)雅可比向量积,而无需重新进行线性化工作。 如果 has_auxTrue,则返回一个 (primals_out, lin_fn, aux) 元组,其中 auxfun 返回的辅助数据。

就计算值而言,linearize() 的行为很像柯里化的 jvp(),其中以下两个代码块计算相同的值

y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

y, f_jvp = jax.linearize(f, x)
out_tangent = f_jvp(in_tangent)

但是,区别在于 linearize() 使用部分求值,以便在调用 f_jvp 时不会重新线性化函数 f。通常,这意味着内存使用量会随着计算量的大小而缩放,这很像反向模式。(实际上,linearize() 具有与 vjp() 类似的签名!)

如果你想多次应用 f_jvp,即在相同的线性化点评估许多不同的输入切向量的推送,那么此函数主要有用。 此外,如果所有输入切向量都一次性已知,则使用 vmap() 进行矢量化可能更有效,如下所示

pushfwd = partial(jvp, f, (x,))
y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))

通过像这样一起使用 vmap()jvp(),我们可以避免存储的线性化内存成本,该成本会随着计算的深度而缩放,linearize()vjp() 都会产生这种成本。

这是使用 linearize() 的更完整示例

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
...
>>> jax.jvp(f, (2.,), (3.,))
(Array(3.2681944, dtype=float32, weak_type=True), Array(-5.007528, dtype=float32, weak_type=True))
>>> y, f_jvp = jax.linearize(f, 2.)
>>> print(y)
3.2681944
>>> print(f_jvp(3.))
-5.007528
>>> print(f_jvp(4.))
-6.676704