jax.jacfwd

内容

jax.jacfwd#

jax.jacfwd(fun, argnums=0, has_aux=False, holomorphic=False)[source]#

使用前向模式自动微分逐列计算 fun 的雅可比矩阵。

参数:
  • fun (Callable) – 要计算雅可比矩阵的函数。

  • argnums (int | Sequence[int]) – 可选,整数或整数序列。指定要对其求导的位置参数(默认为 0)。

  • has_aux (bool) – 可选,布尔值。指示 fun 是否返回一个对,其中第一个元素被视为要微分的数学函数的输出,第二个元素是辅助数据。默认为 False。

  • holomorphic (bool) – 可选,布尔值。指示 fun 是否保证为全纯函数。默认为 False。

返回值:

一个与 fun 具有相同参数的函数,它使用前向模式自动微分计算 fun 的雅可比矩阵。如果 has_aux 为 True,则返回 (雅可比矩阵, 辅助数据) 对。

返回类型:

Callable

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...   return jnp.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]