jax.numpy.permute_dims#
- jax.numpy.permute_dims(a, /, axes)[源代码]#
置换数组的轴/维度。
array_api.permute_dims()
的 JAX 实现。示例
>>> a = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.permute_dims(a, (1, 0)) Array([[1, 4], [2, 5], [3, 6]], dtype=int32)
置换数组的轴/维度。
array_api.permute_dims()
的 JAX 实现。
示例
>>> a = jnp.array([[1, 2, 3],
... [4, 5, 6]])
>>> jnp.permute_dims(a, (1, 0))
Array([[1, 4],
[2, 5],
[3, 6]], dtype=int32)