jax.numpy.permute_dims

jax.numpy.permute_dims#

jax.numpy.permute_dims(a, /, axes)[source]#

置换数组的轴/维度。

JAX 实现 array_api.permute_dims()

参数:
  • a (ArrayLike) – 输入数组

  • axes (tuple[int, ...]) – 范围在 [0, a.ndim) 内的整数元组,指定轴置换。

返回值:

一个轴被置换的 a 副本。

返回值类型:

Array

示例

>>> a = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.permute_dims(a, (1, 0))
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)