jax.numpy.linalg.matrix_transpose

jax.numpy.linalg.matrix_transpose#

jax.numpy.linalg.matrix_transpose(x, /)[source]#

转置矩阵或矩阵堆栈。

JAX 实现 numpy.linalg.matrix_transpose().

参数:

x (ArrayLike) – 形状为 (..., M, N) 的数组

返回值:

形状为 (..., N, M) 的数组,包含 x 的矩阵转置。

返回类型:

数组

另请参阅

jax.numpy.transpose(): 更通用的转置操作。

示例

单个矩阵的转置

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.linalg.matrix_transpose(x)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)

矩阵堆栈的转置

>>> x = jnp.array([[[1, 2],
...                 [3, 4]],
...                [[5, 6],
...                 [7, 8]]])
>>> jnp.linalg.matrix_transpose(x)
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)

为了方便起见,可以使用 JAX 数组对象的 mT 属性进行相同的计算。

>>> x.mT
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)