jax.numpy.roll

内容

jax.numpy.roll#

jax.numpy.roll(a, shift, axis=None)[source]#

沿着指定轴滚动数组的元素。

JAX 实现 numpy.roll().

参数::
  • a (ArrayLike) – 输入数组。

  • shift (ArrayLike | Sequence[int]) – 要滚动指定轴的位数。如果为整数,则所有轴都以相同的量滚动。如果为元组,则分别为每个轴指定滚动量。

  • axis (int | Sequence[int] | None | None) – 要滚动的轴或轴。如果为 None,则数组将被展平、滚动,然后重新整形为其原始形状。

返回::

沿着指定轴或轴滚动元素的 a 的副本。

返回类型::

数组

另请参阅

示例

>>> a = jnp.array([0, 1, 2, 3, 4, 5])
>>> jnp.roll(a, 2)
Array([4, 5, 0, 1, 2, 3], dtype=int32)

沿着特定轴滚动元素

>>> a = jnp.array([[ 0,  1,  2,  3],
...                [ 4,  5,  6,  7],
...                [ 8,  9, 10, 11]])
>>> jnp.roll(a, 1, axis=0)
Array([[ 8,  9, 10, 11],
       [ 0,  1,  2,  3],
       [ 4,  5,  6,  7]], dtype=int32)
>>> jnp.roll(a, [2, 3], axis=[0, 1])
Array([[ 5,  6,  7,  4],
       [ 9, 10, 11,  8],
       [ 1,  2,  3,  0]], dtype=int32)