jax.numpy.diff

目录

jax.numpy.diff#

jax.numpy.diff(a, n=1, axis=-1, prepend=None, append=None)[source]#

沿给定轴计算数组元素之间的 n 阶差。

JAX 实现 numpy.diff().

一阶差由 a[i+1] - a[i] 计算,n 阶差通过递归计算 n 次。

参数::
  • a (ArrayLike) – 输入数组。必须满足 a.ndim >= 1.

  • n (int) – int,可选,默认为 1。差的阶数。指定计算差的次数。如果 n=0,则不计算差,并按原样返回输入。

  • axis (int) – int,可选,默认值为-1。指定计算差值的轴。默认情况下,差值是在 axis -1 轴上计算的。

  • prepend (ArrayLike | None) – 标量或数组,可选,默认值为None。指定在计算差值之前沿 axis 轴预先添加的值。

  • append (ArrayLike | None) – 标量或数组,可选,默认值为None。指定在计算差值之前沿 axis 轴追加的值。

返回值:

包含 a 元素之间 n 阶差值的数组。

返回类型:

数组

另请参阅

示例

jnp.diff 计算 axis 轴上的 一阶差值,默认情况下。

>>> a = jnp.array([[1, 5, 2, 9],
...                [3, 8, 7, 4]])
>>> jnp.diff(a)
Array([[ 4, -3,  7],
       [ 5, -1, -3]], dtype=int32)

n = 2 时,计算 axis 轴上的 二阶差值。

>>> jnp.diff(a, n=2)
Array([[-7, 10],
       [-6, -2]], dtype=int32)

prepend = 2 时,它会在计算差值之前沿 axis 轴预先添加到 a 中。

>>> jnp.diff(a, prepend=2)
Array([[-1,  4, -3,  7],
       [ 1,  5, -1, -3]], dtype=int32)

append = jnp.array([[3],[1]]) 时,它会在计算差值之前沿 axis 轴追加到 a 中。

>>> jnp.diff(a, append=jnp.array([[3],[1]]))
Array([[ 4, -3,  7, -6],
       [ 5, -1, -3, -3]], dtype=int32)