jax.jacfwd#
- jax.jacfwd(fun, argnums=0, has_aux=False, holomorphic=False)[source]#
使用前向模式自动微分逐列计算
fun
的雅可比矩阵。- 参数:
- 返回值:
一个与
fun
具有相同参数的函数,它使用前向模式自动微分计算fun
的雅可比矩阵。如果has_aux
为 True,则返回 (雅可比矩阵, 辅助数据) 对。- 返回类型:
Callable
>>> import jax >>> import jax.numpy as jnp >>> >>> def f(x): ... return jnp.asarray( ... [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])]) ... >>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.]))) [[ 1. 0. 0. ] [ 0. 0. 5. ] [ 0. 16. -2. ] [ 1.6209 0. 0.84147]]