jax.numpy.matrix_transpose#
- jax.numpy.matrix_transpose(x, /)[源代码]#
转置数组的最后两个维度。
JAX 实现的
numpy.matrix_transpose()
,使用jax.lax.transpose()
实现。- 参数:
x (类数组) – 输入数组,必须具有
x.ndim >= 2
- 返回:
数组的矩阵转置副本。
- 返回类型:
参见
jax.Array.mT
:通过Array()
属性访问的相同操作。jax.numpy.transpose()
:一般多轴转置
注意
与
numpy.matrix_transpose()
不同,jax.numpy.matrix_transpose()
将返回输入数组的副本,而不是视图。但是,在 JIT 编译下,编译器会在可能的情况下优化掉这些副本,因此在实践中不会对性能产生影响。示例
这是一个 2x2x2 矩阵,表示一个批次的 2x2 矩阵
>>> x = jnp.array([[[1, 2], ... [3, 4]], ... [[5, 6], ... [7, 8]]]) >>> jnp.matrix_transpose(x) Array([[[1, 3], [2, 4]], [[5, 7], [6, 8]]], dtype=int32)
为了方便起见,您可以通过
mT
属性对jax.Array
执行相同的转置>>> x.mT Array([[[1, 3], [2, 4]], [[5, 7], [6, 8]]], dtype=int32)