jax.numpy.moveaxis

内容

jax.numpy.moveaxis#

jax.numpy.moveaxis(a, source, destination)[source]#

将数组轴移动到新位置

JAX 实现的 numpy.moveaxis(),使用 jax.lax.transpose() 实现。

参数:
  • a (ArrayLike) – 输入数组

  • source (int | Sequence[int]) – 要移动的轴的索引或索引序列。

  • 目标 (int | Sequence[int]) – 轴目标的索引或索引序列

返回值:

将轴从 source 移动到 destinationa 的副本。

返回类型:

数组

备注

numpy.moveaxis() 不同,jax.numpy.moveaxis() 将返回输入数组的副本,而不是视图。但是,在 JIT 下,编译器将在可能的情况下优化掉此类副本,因此在实践中不会对性能产生影响。

另请参见

示例

>>> a = jnp.ones((2, 3, 4, 5))

将轴 1 移动到数组末尾

>>> jnp.moveaxis(a, 1, -1).shape
(2, 4, 5, 3)

将最后一个轴移动到位置 1

>>> jnp.moveaxis(a, -1, 1).shape
(2, 5, 3, 4)

移动多个轴

>>> jnp.moveaxis(a, (0, 1), (-1, -2)).shape
(4, 5, 3, 2)

这也可以通过 transpose() 完成

>>> a.transpose(2, 3, 1, 0).shape
(4, 5, 3, 2)