jax.numpy.roll#
- jax.numpy.roll(a, shift, axis=None)[source]#
沿着指定轴滚动数组的元素。
JAX 实现
numpy.roll()
.- 参数::
- 返回::
沿着指定轴或轴滚动元素的
a
的副本。- 返回类型::
另请参阅
jax.numpy.rollaxis()
: 将指定轴滚动到给定位置。
示例
>>> a = jnp.array([0, 1, 2, 3, 4, 5]) >>> jnp.roll(a, 2) Array([4, 5, 0, 1, 2, 3], dtype=int32)
沿着特定轴滚动元素
>>> a = jnp.array([[ 0, 1, 2, 3], ... [ 4, 5, 6, 7], ... [ 8, 9, 10, 11]]) >>> jnp.roll(a, 1, axis=0) Array([[ 8, 9, 10, 11], [ 0, 1, 2, 3], [ 4, 5, 6, 7]], dtype=int32) >>> jnp.roll(a, [2, 3], axis=[0, 1]) Array([[ 5, 6, 7, 4], [ 9, 10, 11, 8], [ 1, 2, 3, 0]], dtype=int32)