jax.numpy.gradient#

jax.numpy.gradient(f, *varargs, axis=None, edge_order=None)[源代码]#

计算采样函数的数值梯度。

numpy.gradient() 的 JAX 实现。

jnp.gradient 中的梯度是使用采样函数值数组上的二阶有限差分计算的。这不应与 jax.grad() 混淆,后者通过自动微分计算可调用函数的精确梯度。

参数:
  • f (ArrayLike) – 函数值的 N 维数组。

  • varargs (ArrayLike) –

    可选的标量或数组列表,指定函数评估的间隔。选项包括

    • 未指定:所有维度上的单位间隔。

    • 单个标量:所有维度上的恒定间隔。

    • N 个值:指定每个维度上的不同间隔

      • 标量值表示该维度上的恒定间隔。

      • 数组值必须与相应维度的长度匹配,并指定评估 f 的坐标。

  • edge_order (int | None) – 在 JAX 中未实现

  • axis (int | Sequence[int] | None) – 指定计算梯度的轴的整数或整数元组。如果为 None(默认),则沿所有轴计算梯度。

返回:

包含沿每个指定轴的数值梯度的数组或数组元组。

返回类型:

Array | list[Array]

另请参阅

  • jax.grad():对具有单个输出的函数进行自动微分。

示例

比较简单函数的数值微分和自动微分

>>> def f(x):
...   return jnp.sin(x) * jnp.exp(-x / 4)
...
>>> def gradf_exact(x):
...   # exact analytical gradient of f(x)
...   return -f(x) / 4 + jnp.cos(x) * jnp.exp(-x / 4)
...
>>> x = jnp.linspace(0, 5, 10)
>>> with jnp.printoptions(precision=2, suppress=True):
...   print("numerical gradient:", jnp.gradient(f(x), x))
...   print("automatic gradient:", jax.vmap(jax.grad(f))(x))
...   print("exact gradient:    ", gradf_exact(x))
...
numerical gradient: [ 0.83  0.61  0.18 -0.2  -0.43 -0.49 -0.39 -0.21 -0.02  0.08]
automatic gradient: [ 1.    0.62  0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01  0.15]
exact gradient:     [ 1.    0.62  0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01  0.15]

请注意,正如预期的那样,与通过 jax.grad() 计算的自动梯度相比,数值梯度存在一些近似误差。