jax.numpy.gradient#
- jax.numpy.gradient(f, *varargs, axis=None, edge_order=None)[source]#
返回 N 维数组的梯度。
LAX 后端实现
numpy.gradient()
.原始文档字符串如下。
梯度使用二阶精确中心差分在内部点计算,使用一阶或二阶精确单边(向前或向后)差分在边界计算。因此,返回的梯度与输入数组具有相同的形状。
- 参数:
f (array_like) – 包含标量函数样本的 N 维数组。
varargs (list of scalar or array, optional) –
f 值之间的间距。所有维度的默认单位间距。可以使用
单个标量指定所有维度的样本距离。
N 个标量指定每个维度的常数样本距离。即 dx,dy,dz,…
N 个数组指定 F 沿每个维度值的坐标。数组的长度必须与相应维度的尺寸匹配
2. 和 3. 中 N 个标量/数组的任何组合。
如果提供了 axis,则变长参数的数量必须等于轴的数量。默认值:1。(参见下面的示例)。
axis (None 或 int 或 tuple of ints, 可选) – 仅沿给定轴或轴计算梯度。默认值(axis = None)是对输入数组的所有轴计算梯度。axis 可以为负数,在这种情况下它从最后一个轴到第一个轴计数。
edge_order (int | None)
- 返回值:
gradient – 一个 ndarray 元组(如果只有一个维度,则为单个 ndarray),对应于 f 相对于每个维度的导数。每个导数都与 f 具有相同的形状。
- 返回类型:
ndarray 或 tuple of ndarray
参考文献