jax.numpy.ediff1d#

jax.numpy.ediff1d(ary, to_end=None, to_begin=None)[源代码]#

计算展平数组的元素之间的差异。

numpy.ediff1d() 的 JAX 实现。

参数
  • ary (ArrayLike) – 输入数组或标量。

  • to_end (ArrayLike | None) – 标量或数组,可选,默认=None。指定要附加到结果数组的数字。

  • to_begin (ArrayLike | None) – 标量或数组,可选,默认=None。指定要添加到结果数组开头的数字。

返回

一个数组,包含输入数组元素之间的差异。

返回类型

数组

注意

与 NumPy 的 ediff1d 实现不同,如果将 to_endto_begin 转换为 ary 的类型时发生精度损失,jax.numpy.ediff1d() 不会发出错误。

另请参阅

示例

>>> a = jnp.array([2, 3, 5, 9, 1, 4])
>>> jnp.ediff1d(a)
Array([ 1,  2,  4, -8,  3], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10)
Array([-10,   1,   2,   4,  -8,   3], dtype=int32)
>>> jnp.ediff1d(a, to_end=jnp.array([20, 30]))
Array([ 1,  2,  4, -8,  3, 20, 30], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10, to_end=jnp.array([20, 30]))
Array([-10,   1,   2,   4,  -8,   3,  20,  30], dtype=int32)

对于 ndim > 1 的数组,差异是在展平输入数组后计算的。

>>> a1 = jnp.array([[2, -1, 4, 7],
...                 [3, 5, -6, 9]])
>>> jnp.ediff1d(a1)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)
>>> a2 = jnp.array([2, -1, 4, 7, 3, 5, -6, 9])
>>> jnp.ediff1d(a2)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)