jax.numpy.roll#

jax.numpy.roll(a, shift, axis=None)[源代码]#

沿指定轴滚动数组的元素。

JAX实现的 numpy.roll()

参数:
  • a (ArrayLike) – 输入数组。

  • shift (ArrayLike | Sequence[int]) – 指定轴移动的位置数。如果是一个整数,则所有轴移动相同的数量。如果是一个元组,则为每个轴单独指定移动量。

  • axis (int | Sequence[int] | None | None) – 要滚动的轴或多个轴。如果为 None,则数组被展平,移动,然后重塑为原始形状。

返回:

沿指定轴或多个轴滚动元素的 a 的副本。

返回类型:

Array

另请参阅

示例

>>> a = jnp.array([0, 1, 2, 3, 4, 5])
>>> jnp.roll(a, 2)
Array([4, 5, 0, 1, 2, 3], dtype=int32)

沿特定轴滚动元素

>>> a = jnp.array([[ 0,  1,  2,  3],
...                [ 4,  5,  6,  7],
...                [ 8,  9, 10, 11]])
>>> jnp.roll(a, 1, axis=0)
Array([[ 8,  9, 10, 11],
       [ 0,  1,  2,  3],
       [ 4,  5,  6,  7]], dtype=int32)
>>> jnp.roll(a, [2, 3], axis=[0, 1])
Array([[ 5,  6,  7,  4],
       [ 9, 10, 11,  8],
       [ 1,  2,  3,  0]], dtype=int32)