jax.jvp#
- jax.jvp(fun, primals, tangents, has_aux=False)[source]#
计算
fun
的(前向模式)雅可比-向量积。- 参数:
fun (Callable) – 需要进行微分的函数。它的参数应该是数组、标量,或者是数组或标量的标准 Python 容器。它应该返回一个数组、标量,或者是数组或标量的标准 Python 容器。
primals – 应该计算
fun
的雅可比矩阵的原始值。应该是一个参数的元组或列表,其长度应该等于fun
的位置参数的数量。tangents – 应该计算雅可比矩阵-向量积的切向量。应该是一个切向量的元组或列表,具有与
primals
相同的树结构和数组形状。has_aux (bool) – 可选,布尔值。指示
fun
是否返回一个对,其中第一个元素被认为是需要进行微分的数学函数的输出,而第二个元素是辅助数据。默认为 False。
- 返回值:
如果
has_aux
为False
,则返回一个(primals_out, tangents_out)
对,其中primals_out
是fun(*primals)
,而tangents_out
是在primals
处使用tangents
计算的function
的雅可比矩阵-向量积。tangents_out
值具有与primals_out
相同的 Python 树结构和形状。如果has_aux
为True
,则返回一个(primals_out, tangents_out, aux)
元组,其中aux
是由fun
返回的辅助数据。- 返回类型:
tuple[Any, …]
例如
>>> import jax >>> >>> primals, tangents = jax.jvp(jax.numpy.sin, (0.1,), (0.2,)) >>> print(primals) 0.09983342 >>> print(tangents) 0.19900084