jax.numpy.diff#
- jax.numpy.diff(a, n=1, axis=-1, prepend=None, append=None)[source]#
沿给定轴计算数组元素之间的 n 阶差。
JAX 实现
numpy.diff()
.一阶差由
a[i+1] - a[i]
计算,n 阶差通过递归计算n
次。- 参数::
a (ArrayLike) – 输入数组。必须满足
a.ndim >= 1
.n (int) – int,可选,默认为 1。差的阶数。指定计算差的次数。如果 n=0,则不计算差,并按原样返回输入。
axis (int) – int,可选,默认值为-1。指定计算差值的轴。默认情况下,差值是在
axis -1
轴上计算的。prepend (ArrayLike | None) – 标量或数组,可选,默认值为None。指定在计算差值之前沿
axis
轴预先添加的值。append (ArrayLike | None) – 标量或数组,可选,默认值为None。指定在计算差值之前沿
axis
轴追加的值。
- 返回值:
包含
a
元素之间 n 阶差值的数组。- 返回类型:
另请参阅
jax.numpy.ediff1d()
: 计算数组中连续元素之间的差值。jax.numpy.cumsum()
: 计算数组沿给定轴的元素的累积和。jax.numpy.gradient()
: 计算 N 维数组的梯度。
示例
jnp.diff
计算axis
轴上的 一阶差值,默认情况下。>>> a = jnp.array([[1, 5, 2, 9], ... [3, 8, 7, 4]]) >>> jnp.diff(a) Array([[ 4, -3, 7], [ 5, -1, -3]], dtype=int32)
当
n = 2
时,计算axis
轴上的 二阶差值。>>> jnp.diff(a, n=2) Array([[-7, 10], [-6, -2]], dtype=int32)
当
prepend = 2
时,它会在计算差值之前沿axis
轴预先添加到a
中。>>> jnp.diff(a, prepend=2) Array([[-1, 4, -3, 7], [ 1, 5, -1, -3]], dtype=int32)
当
append = jnp.array([[3],[1]])
时,它会在计算差值之前沿axis
轴追加到a
中。>>> jnp.diff(a, append=jnp.array([[3],[1]])) Array([[ 4, -3, 7, -6], [ 5, -1, -3, -3]], dtype=int32)