jax.numpy.gradient

目录

jax.numpy.gradient#

jax.numpy.gradient(f, *varargs, axis=None, edge_order=None)[source]#

返回 N 维数组的梯度。

LAX 后端实现 numpy.gradient().

原始文档字符串如下。

梯度使用二阶精确中心差分在内部点计算,使用一阶或二阶精确单边(向前或向后)差分在边界计算。因此,返回的梯度与输入数组具有相同的形状。

参数:
  • f (array_like) – 包含标量函数样本的 N 维数组。

  • varargs (list of scalar or array, optional) –

    f 值之间的间距。所有维度的默认单位间距。可以使用

    1. 单个标量指定所有维度的样本距离。

    2. N 个标量指定每个维度的常数样本距离。即 dxdydz,…

    3. N 个数组指定 F 沿每个维度值的坐标。数组的长度必须与相应维度的尺寸匹配

    4. 2. 和 3. 中 N 个标量/数组的任何组合。

    如果提供了 axis,则变长参数的数量必须等于轴的数量。默认值:1。(参见下面的示例)。

  • axis (Noneinttuple of ints, 可选) – 仅沿给定轴或轴计算梯度。默认值(axis = None)是对输入数组的所有轴计算梯度。axis 可以为负数,在这种情况下它从最后一个轴到第一个轴计数。

  • edge_order (int | None)

返回值:

gradient – 一个 ndarray 元组(如果只有一个维度,则为单个 ndarray),对应于 f 相对于每个维度的导数。每个导数都与 f 具有相同的形状。

返回类型:

ndarray 或 tuple of ndarray

参考文献