jax.numpy.roll#
- jax.numpy.roll(a, shift, axis=None)[源代码]#
沿指定轴滚动数组的元素。
numpy.roll()
的 JAX 实现。- 参数:
- 返回:
一个
a
的副本,其元素沿指定的轴或多个轴滚动。- 返回类型:
另请参阅
jax.numpy.rollaxis()
:将指定的轴滚动到给定的位置。
示例
>>> a = jnp.array([0, 1, 2, 3, 4, 5]) >>> jnp.roll(a, 2) Array([4, 5, 0, 1, 2, 3], dtype=int32)
沿特定轴滚动元素
>>> a = jnp.array([[ 0, 1, 2, 3], ... [ 4, 5, 6, 7], ... [ 8, 9, 10, 11]]) >>> jnp.roll(a, 1, axis=0) Array([[ 8, 9, 10, 11], [ 0, 1, 2, 3], [ 4, 5, 6, 7]], dtype=int32) >>> jnp.roll(a, [2, 3], axis=[0, 1]) Array([[ 5, 6, 7, 4], [ 9, 10, 11, 8], [ 1, 2, 3, 0]], dtype=int32)