jax.numpy.diff#

jax.numpy.diff(a, n=1, axis=-1, prepend=None, append=None)[源代码]#

计算给定轴上数组元素之间的 n 阶差分。

JAX 实现的 numpy.diff()

一阶差分通过 a[i+1] - a[i] 计算,而 n 阶差分则递归计算 n 次。

参数:
  • a (类似数组) – 输入数组。必须满足 a.ndim >= 1

  • n (int) – int,可选,默认值=1。差分的阶数。指定计算差分的次数。如果 n=0,则不计算差分,并按原样返回输入。

  • axis (int) – int,可选,默认值=-1。指定计算差分的轴。默认沿 axis -1 计算差分。

  • prepend (类似数组 | None) – 标量或数组,可选,默认值=None。指定在计算差分之前沿 axis 预先添加的值。

  • append (类似数组 | None) – 标量或数组,可选,默认值=None。指定在计算差分之前沿 axis 追加的值。

返回:

一个包含 a 的元素之间的 n 阶差分的数组。

返回类型:

数组

另请参阅

示例

jnp.diff 默认计算沿 axis 的一阶差分。

>>> a = jnp.array([[1, 5, 2, 9],
...                [3, 8, 7, 4]])
>>> jnp.diff(a)
Array([[ 4, -3,  7],
       [ 5, -1, -3]], dtype=int32)

n = 2 时,计算沿 axis 的二阶差分。

>>> jnp.diff(a, n=2)
Array([[-7, 10],
       [-6, -2]], dtype=int32)

prepend = 2 时,在计算差分之前,它会沿 axis 预添加到 a

>>> jnp.diff(a, prepend=2)
Array([[-1,  4, -3,  7],
       [ 1,  5, -1, -3]], dtype=int32)

append = jnp.array([[3],[1]]) 时,在计算差分之前,它会沿 axis 追加到 a

>>> jnp.diff(a, append=jnp.array([[3],[1]]))
Array([[ 4, -3,  7, -6],
       [ 5, -1, -3, -3]], dtype=int32)