jax.jvp#

jax.jvp(fun, primals, tangents, has_aux=False)[source]#

计算 fun 的(前向模式)雅可比-向量积。

参数:
  • fun (Callable) – 需要进行微分的函数。它的参数应该是数组、标量,或者是数组或标量的标准 Python 容器。它应该返回一个数组、标量,或者是数组或标量的标准 Python 容器。

  • primals – 应该计算 fun 的雅可比矩阵的原始值。应该是一个参数的元组或列表,其长度应该等于 fun 的位置参数的数量。

  • tangents – 应该计算雅可比矩阵-向量积的切向量。应该是一个切向量的元组或列表,具有与 primals 相同的树结构和数组形状。

  • has_aux (bool) – 可选,布尔值。指示 fun 是否返回一个对,其中第一个元素被认为是需要进行微分的数学函数的输出,而第二个元素是辅助数据。默认为 False。

返回值:

如果 has_auxFalse,则返回一个 (primals_out, tangents_out) 对,其中 primals_outfun(*primals),而 tangents_out 是在 primals 处使用 tangents 计算的 function 的雅可比矩阵-向量积。tangents_out 值具有与 primals_out 相同的 Python 树结构和形状。如果 has_auxTrue,则返回一个 (primals_out, tangents_out, aux) 元组,其中 aux 是由 fun 返回的辅助数据。

返回类型:

tuple[Any, …]

例如

>>> import jax
>>>
>>> primals, tangents = jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
>>> print(primals)
0.09983342
>>> print(tangents)
0.19900084