jax.numpy.moveaxis#
- jax.numpy.moveaxis(a, source, destination)[源代码]#
将数组轴移动到新位置
JAX 实现的
numpy.moveaxis()
,使用jax.lax.transpose()
实现。- 参数:
- 返回:
将轴从
source
移动到destination
后的a
的副本。- 返回类型:
注释
与
numpy.moveaxis()
不同,jax.numpy.moveaxis()
将返回输入数组的副本,而不是视图。然而,在 JIT 下,编译器会在可能的情况下优化掉此类副本,因此这在实践中不会影响性能。另请参阅
jax.numpy.swapaxes()
: 交换两个轴。jax.numpy.rollaxis()
: 用于移动轴的较旧 API。jax.numpy.transpose()
: 通用轴置换。
示例
>>> a = jnp.ones((2, 3, 4, 5))
将轴
1
移动到数组的末尾>>> jnp.moveaxis(a, 1, -1).shape (2, 4, 5, 3)
将最后一个轴移动到位置 1
>>> jnp.moveaxis(a, -1, 1).shape (2, 5, 3, 4)
移动多个轴
>>> jnp.moveaxis(a, (0, 1), (-1, -2)).shape (4, 5, 3, 2)
这也可以通过
transpose()
完成>>> a.transpose(2, 3, 1, 0).shape (4, 5, 3, 2)