jax.numpy.matrix_transpose

jax.numpy.matrix_transpose#

jax.numpy.matrix_transpose(x, /)[source]#

转置数组的最后两个维度。

JAX 实现 numpy.matrix_transpose(),使用 jax.lax.transpose() 实现。

参数:

x (ArrayLike) – 输入数组,必须具有 x.ndim >= 2

返回值:

数组的矩阵转置副本。

返回值类型:

数组

另请参阅

注意

numpy.matrix_transpose() 不同,jax.numpy.matrix_transpose() 将返回输入数组的副本,而不是视图。但是,在 JIT 下,编译器将在可能的情况下优化掉此类副本,因此在实践中不会影响性能。

示例

这是一个 2x2x2 矩阵,表示批处理的 2x2 矩阵

>>> x = jnp.array([[[1, 2],
...                 [3, 4]],
...                [[5, 6],
...                 [7, 8]]])
>>> jnp.matrix_transpose(x)
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)

为了方便起见,您可以通过 mT 属性对 jax.Array 执行相同的转置操作。

>>> x.mT
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)