jax.numpy.ediff1d#
- jax.numpy.ediff1d(ary, to_end=None, to_begin=None)[源代码]#
计算展平数组的元素之间的差异。
numpy.ediff1d()
的 JAX 实现。- 参数:
ary (ArrayLike) – 输入数组或标量。
to_end (ArrayLike | None) – 标量或数组,可选,默认=None。指定要附加到结果数组的数字。
to_begin (ArrayLike | None) – 标量或数组,可选,默认=None。指定要添加到结果数组开头的数字。
- 返回:
一个数组,包含输入数组元素之间的差异。
- 返回类型:
注意
与 NumPy 的 ediff1d 实现不同,如果将
to_end
或to_begin
转换为ary
的类型时发生精度损失,jax.numpy.ediff1d()
不会发出错误。另请参阅
jax.numpy.diff()
:计算沿给定轴的数组元素之间的 n 阶差分。jax.numpy.cumsum()
:计算沿给定轴的数组元素的累积和。jax.numpy.gradient()
:计算 N 维数组的梯度。
示例
>>> a = jnp.array([2, 3, 5, 9, 1, 4]) >>> jnp.ediff1d(a) Array([ 1, 2, 4, -8, 3], dtype=int32) >>> jnp.ediff1d(a, to_begin=-10) Array([-10, 1, 2, 4, -8, 3], dtype=int32) >>> jnp.ediff1d(a, to_end=jnp.array([20, 30])) Array([ 1, 2, 4, -8, 3, 20, 30], dtype=int32) >>> jnp.ediff1d(a, to_begin=-10, to_end=jnp.array([20, 30])) Array([-10, 1, 2, 4, -8, 3, 20, 30], dtype=int32)
对于
ndim > 1
的数组,差异是在展平输入数组后计算的。>>> a1 = jnp.array([[2, -1, 4, 7], ... [3, 5, -6, 9]]) >>> jnp.ediff1d(a1) Array([ -3, 5, 3, -4, 2, -11, 15], dtype=int32) >>> a2 = jnp.array([2, -1, 4, 7, 3, 5, -6, 9]) >>> jnp.ediff1d(a2) Array([ -3, 5, 3, -4, 2, -11, 15], dtype=int32)