jax.numpy.ediff1d#
- jax.numpy.ediff1d(ary, to_end=None, to_begin=None)[source]#
计算扁平化数组元素的差值。
JAX 实现
numpy.ediff1d()
.- 参数:
ary (ArrayLike) – 输入数组或标量。
to_end (ArrayLike | None) – 标量或数组,可选,默认为 None。指定要追加到结果数组的数字。
to_begin (ArrayLike | None) – 标量或数组,可选,默认为 None。指定要预先添加到结果数组的数字。
- 返回值:
包含输入数组元素之间差值的数组。
- 返回类型:
注意
与 NumPy 中 ediff1d 的实现不同,
jax.numpy.ediff1d()
不会在将to_end
或to_begin
转换为ary
类型时丢失精度的情况下报错。另请参阅
jax.numpy.diff()
: 计算数组沿给定轴的元素之间的 n 阶差分。jax.numpy.cumsum()
: 计算数组沿给定轴的元素的累积和。jax.numpy.gradient()
: 计算 N 维数组的梯度。
示例
>>> a = jnp.array([2, 3, 5, 9, 1, 4]) >>> jnp.ediff1d(a) Array([ 1, 2, 4, -8, 3], dtype=int32) >>> jnp.ediff1d(a, to_begin=-10) Array([-10, 1, 2, 4, -8, 3], dtype=int32) >>> jnp.ediff1d(a, to_end=jnp.array([20, 30])) Array([ 1, 2, 4, -8, 3, 20, 30], dtype=int32) >>> jnp.ediff1d(a, to_begin=-10, to_end=jnp.array([20, 30])) Array([-10, 1, 2, 4, -8, 3, 20, 30], dtype=int32)
对于
ndim > 1
的数组,差分是在展平输入数组后计算的。>>> a1 = jnp.array([[2, -1, 4, 7], ... [3, 5, -6, 9]]) >>> jnp.ediff1d(a1) Array([ -3, 5, 3, -4, 2, -11, 15], dtype=int32) >>> a2 = jnp.array([2, -1, 4, 7, 3, 5, -6, 9]) >>> jnp.ediff1d(a2) Array([ -3, 5, 3, -4, 2, -11, 15], dtype=int32)