jax.numpy.linalg.qr#

jax.numpy.linalg.qr(a, mode='reduced')[源代码]#

计算数组的 QR 分解

numpy.linalg.qr() 的 JAX 实现。

矩阵 A 的 QR 分解由下式给出

\[A = QR\]

其中 Q 是一个酉矩阵(即 \(Q^HQ=I\)),而 R 是一个上三角矩阵。

参数:
  • a (ArrayLike) – 形状为 (…, M, N) 的数组

  • mode (str) –

    计算模式。支持的值为

    • "reduced" (默认值): 返回形状为 (..., M, K)Q 和形状为 (..., K, N)R,其中 K = min(M, N)

    • "complete": 返回形状为 (..., M, M)Q 和形状为 (..., M, N)R

    • "raw": 返回形状为 (..., M, N)(..., K) 的 lapack 内部表示。

    • "r": 仅返回 R

返回值:

如果 mode 不是 "r",则返回一个元组 (Q, R),否则返回一个数组 R,其中

  • Q 是形状为 (..., M, K) (如果 mode"reduced") 或 (..., M, M) (如果 mode"complete") 的正交矩阵。

  • R 是形状为 (..., M, N) (如果 mode"r""complete") 或 (..., K, N) (如果 mode"reduced") 的上三角矩阵。

其中 K = min(M, N)

返回类型:

Array | QRResult

另请参阅

示例

计算矩阵的 QR 分解

>>> a = jnp.array([[1., 2., 3., 4.],
...                [5., 4., 2., 1.],
...                [6., 3., 1., 5.]])
>>> Q, R = jnp.linalg.qr(a)
>>> Q  
Array([[-0.12700021, -0.7581426 , -0.6396022 ],
       [-0.63500065, -0.43322435,  0.63960224],
       [-0.7620008 ,  0.48737738, -0.42640156]], dtype=float32)
>>> R  
Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ],
       [ 0.       , -1.7870499, -2.6534991, -1.028908 ],
       [ 0.       ,  0.       , -1.0660033, -4.050814 ]], dtype=float32)

检查 Q 是否是正交矩阵

>>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5)
Array(True, dtype=bool)

重构输入

>>> jnp.allclose(Q @ R, a)
Array(True, dtype=bool)