jax.vjp

内容

jax.vjp#

jax.vjp(fun: Callable[..., T], *primals: Any, has_aux: Literal[False] = False, reduce_axes: Sequence[AxisName] = ()) tuple[T, Callable][source]#
jax.vjp(fun: Callable[..., tuple[T, U]], *primals: Any, has_aux: Literal[True], reduce_axes: Sequence[AxisName] = ()) tuple[T, Callable, U]

计算 fun 的(反向模式)向量-雅可比积。

grad() 是作为 vjp() 的一个特例实现的。

参数:
  • fun – 要进行微分的函数。其参数应为数组、标量或数组或标量的标准 Python 容器。它应该返回一个数组、标量或数组或标量的标准 Python 容器。

  • primals – 一系列原始值,在这些值处应评估 fun 的雅可比矩阵。 primals 的数量应等于 fun 的位置参数的数量。每个原始值都应该是一个数组、一个标量或它们的 pytree(标准 Python 容器)。

  • has_aux – 可选,布尔值。指示 fun 是否返回一个对,其中第一个元素被视为要微分的数学函数的输出,第二个元素是辅助数据。默认为 False。

返回值:

如果 has_auxFalse,则返回一个 (primals_out, vjpfun) 对,其中 primals_outfun(*primals)。如果 has_auxTrue,则返回一个 (primals_out, vjpfun, aux) 元组,其中 auxfun 返回的辅助数据。

vjpfun 是一个从与 primals_out 形状相同的余切向量到与 primals 数量和形状相同的余切向量元组的函数,表示在 primals 处评估的 fun 的向量-雅可比积。

>>> import jax
>>>
>>> def f(x, y):
...   return jax.numpy.sin(x), jax.numpy.cos(y)
...
>>> primals, f_vjp = jax.vjp(f, 0.5, 1.0)
>>> xbar, ybar = f_vjp((-0.7, 0.3))
>>> print(xbar)
-0.61430776
>>> print(ybar)
-0.2524413