jax.numpy.permute_dims#
- jax.numpy.permute_dims(a, /, axes)[source]#
置换数组的轴/维度。
JAX 实现
array_api.permute_dims()
。- 参数:
- 返回值:
一个轴被置换的
a
副本。- 返回值类型:
示例
>>> a = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.permute_dims(a, (1, 0)) Array([[1, 4], [2, 5], [3, 6]], dtype=int32)