jax.numpy.rollaxis#
- jax.numpy.rollaxis(a, axis, start=0)[source]#
将指定的轴滚动到给定位置。
numpy.rollaxis()
的 JAX 实现。此函数的存在是为了与 NumPy 保持兼容性,但在大多数情况下,可以使用更新的
jax.numpy.moveaxis()
,因为其参数的含义更直观。- 参数:
- 返回值:
具有已滚动轴的
a
的副本。- 返回类型:
备注
与
numpy.rollaxis()
不同,jax.numpy.rollaxis()
将返回输入数组的副本而不是视图。但是,在 JIT 下,编译器将在可能的情况下优化掉此类副本,因此在实践中不会对性能造成影响。另请参阅
jax.numpy.moveaxis()
:比rollaxis
语义更清晰的新 API;在大多数情况下,应该优先选择它而不是rollaxis
。jax.numpy.swapaxes()
:交换两个轴。jax.numpy.transpose()
:轴的一般排列。
示例
>>> a = jnp.ones((2, 3, 4, 5))
将轴 2 滚动到数组的开头
>>> jnp.rollaxis(a, 2).shape (4, 2, 3, 5)
将轴 1 滚动到数组的末尾
>>> jnp.rollaxis(a, 1, a.ndim).shape (2, 4, 5, 3)
使用
moveaxis()
等效于以上两个操作>>> jnp.moveaxis(a, 2, 0).shape (4, 2, 3, 5) >>> jnp.moveaxis(a, 1, -1).shape (2, 4, 5, 3)