jax.numpy.linalg.qr#
- jax.numpy.linalg.qr(a, mode='reduced')[源代码]#
计算数组的 QR 分解
JAX 实现的
numpy.linalg.qr()
。矩阵 A 的 QR 分解由下式给出:
\[A = QR\]其中 Q 是酉矩阵(即 \(Q^HQ=I\)),而 R 是上三角矩阵。
- 参数:
a (ArrayLike) – 形状为 (…, M, N) 的数组
mode (str) –
计算模式。支持的值为:
"reduced"
(默认值):返回形状为(..., M, K)
的 Q 和形状为(..., K, N)
的 R,其中K = min(M, N)
。"complete"
:返回形状为(..., M, M)
的 Q 和形状为(..., M, N)
的 R。"raw"
:返回形状为(..., M, N)
和(..., K)
的 lapack 内部表示。"r"
:仅返回 R。
- 返回:
一个元组
(Q, R)
(如果mode
不是"r"
),否则返回一个数组R
,其中Q
是形状为(..., M, K)
(如果mode
是"reduced"
)或(..., M, M)
(如果mode
是"complete"
)的正交矩阵。R
是形状为(..., M, N)
(如果mode
是"r"
或"complete"
)或(..., K, N)
(如果mode
是"reduced"
)的上三角矩阵。
其中
K = min(M, N)
。- 返回类型:
Array | QRResult
另请参阅
jax.scipy.linalg.qr()
:SciPy 风格的 QR 分解 APIjax.lax.linalg.qr()
:XLA 风格的 QR 分解 API
示例
计算矩阵的 QR 分解
>>> a = jnp.array([[1., 2., 3., 4.], ... [5., 4., 2., 1.], ... [6., 3., 1., 5.]]) >>> Q, R = jnp.linalg.qr(a) >>> Q Array([[-0.12700021, -0.7581426 , -0.6396022 ], [-0.63500065, -0.43322435, 0.63960224], [-0.7620008 , 0.48737738, -0.42640156]], dtype=float32) >>> R Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ], [ 0. , -1.7870499, -2.6534991, -1.028908 ], [ 0. , 0. , -1.0660033, -4.050814 ]], dtype=float32)
检查
Q
是否是标准正交的>>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5) Array(True, dtype=bool)
重建输入
>>> jnp.allclose(Q @ R, a) Array(True, dtype=bool)