jax.numpy.diff#
- jax.numpy.diff(a, n=1, axis=-1, prepend=None, append=None)[源代码]#
计算给定轴上数组元素之间的 n 阶差分。
JAX 实现的
numpy.diff()
。一阶差分通过
a[i+1] - a[i]
计算,而 n 阶差分则递归计算n
次。- 参数:
- 返回:
一个包含
a
的元素之间的 n 阶差分的数组。- 返回类型:
另请参阅
jax.numpy.ediff1d()
:计算数组中连续元素之间的差值。jax.numpy.cumsum()
:计算沿给定轴的数组元素的累积和。jax.numpy.gradient()
:计算 N 维数组的梯度。
示例
jnp.diff
默认计算沿axis
的一阶差分。>>> a = jnp.array([[1, 5, 2, 9], ... [3, 8, 7, 4]]) >>> jnp.diff(a) Array([[ 4, -3, 7], [ 5, -1, -3]], dtype=int32)
当
n = 2
时,计算沿axis
的二阶差分。>>> jnp.diff(a, n=2) Array([[-7, 10], [-6, -2]], dtype=int32)
当
prepend = 2
时,在计算差分之前,它会沿axis
预添加到a
。>>> jnp.diff(a, prepend=2) Array([[-1, 4, -3, 7], [ 1, 5, -1, -3]], dtype=int32)
当
append = jnp.array([[3],[1]])
时,在计算差分之前,它会沿axis
追加到a
。>>> jnp.diff(a, append=jnp.array([[3],[1]])) Array([[ 4, -3, 7, -6], [ 5, -1, -3, -3]], dtype=int32)