jax.numpy.rollaxis

内容

jax.numpy.rollaxis#

jax.numpy.rollaxis(a, axis, start=0)[source]#

将指定的轴滚动到给定位置。

numpy.rollaxis() 的 JAX 实现。

此函数的存在是为了与 NumPy 保持兼容性,但在大多数情况下,可以使用更新的 jax.numpy.moveaxis(),因为其参数的含义更直观。

参数:
  • a (ArrayLike) – 输入数组。

  • axis (int) – 要向前滚动的轴的索引。

  • start (int) – 轴将滚动到的索引(默认值 = 0)。在规范化负轴后,如果 start <= axis,则轴滚动到 start 索引;如果 start > axis,则轴滚动到 start 之前的那个位置。

返回值:

具有已滚动轴的 a 的副本。

返回类型:

数组

备注

numpy.rollaxis() 不同,jax.numpy.rollaxis() 将返回输入数组的副本而不是视图。但是,在 JIT 下,编译器将在可能的情况下优化掉此类副本,因此在实践中不会对性能造成影响。

另请参阅

示例

>>> a = jnp.ones((2, 3, 4, 5))

将轴 2 滚动到数组的开头

>>> jnp.rollaxis(a, 2).shape
(4, 2, 3, 5)

将轴 1 滚动到数组的末尾

>>> jnp.rollaxis(a, 1, a.ndim).shape
(2, 4, 5, 3)

使用 moveaxis() 等效于以上两个操作

>>> jnp.moveaxis(a, 2, 0).shape
(4, 2, 3, 5)
>>> jnp.moveaxis(a, 1, -1).shape
(2, 4, 5, 3)