jax.numpy.permute_dims#
- jax.numpy.permute_dims(a, /, axes)[源代码]#
置换数组的轴/维度。
array_api.permute_dims()
的 JAX 实现。- 参数:
- 返回:
一个轴已置换的
a
的副本。- 返回类型:
示例
>>> a = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.permute_dims(a, (1, 0)) Array([[1, 4], [2, 5], [3, 6]], dtype=int32)