jax.numpy.ediff1d

内容

jax.numpy.ediff1d#

jax.numpy.ediff1d(ary, to_end=None, to_begin=None)[source]#

计算扁平化数组元素的差值。

JAX 实现 numpy.ediff1d().

参数:
  • ary (ArrayLike) – 输入数组或标量。

  • to_end (ArrayLike | None) – 标量或数组,可选,默认为 None。指定要追加到结果数组的数字。

  • to_begin (ArrayLike | None) – 标量或数组,可选,默认为 None。指定要预先添加到结果数组的数字。

返回值:

包含输入数组元素之间差值的数组。

返回类型:

数组

注意

与 NumPy 中 ediff1d 的实现不同,jax.numpy.ediff1d() 不会在将 to_endto_begin 转换为 ary 类型时丢失精度的情况下报错。

另请参阅

示例

>>> a = jnp.array([2, 3, 5, 9, 1, 4])
>>> jnp.ediff1d(a)
Array([ 1,  2,  4, -8,  3], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10)
Array([-10,   1,   2,   4,  -8,   3], dtype=int32)
>>> jnp.ediff1d(a, to_end=jnp.array([20, 30]))
Array([ 1,  2,  4, -8,  3, 20, 30], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10, to_end=jnp.array([20, 30]))
Array([-10,   1,   2,   4,  -8,   3,  20,  30], dtype=int32)

对于 ndim > 1 的数组,差分是在展平输入数组后计算的。

>>> a1 = jnp.array([[2, -1, 4, 7],
...                 [3, 5, -6, 9]])
>>> jnp.ediff1d(a1)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)
>>> a2 = jnp.array([2, -1, 4, 7, 3, 5, -6, 9])
>>> jnp.ediff1d(a2)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)