jax.numpy.linalg.qr#
- jax.numpy.linalg.qr(a, mode='reduced')[源代码]#
计算数组的 QR 分解
numpy.linalg.qr()
的 JAX 实现。矩阵 A 的 QR 分解由下式给出
\[A = QR\]其中 Q 是一个酉矩阵(即 \(Q^HQ=I\)),而 R 是一个上三角矩阵。
- 参数:
a (ArrayLike) – 形状为 (…, M, N) 的数组
mode (str) –
计算模式。支持的值为
"reduced"
(默认值): 返回形状为(..., M, K)
的 Q 和形状为(..., K, N)
的 R,其中K = min(M, N)
。"complete"
: 返回形状为(..., M, M)
的 Q 和形状为(..., M, N)
的 R。"raw"
: 返回形状为(..., M, N)
和(..., K)
的 lapack 内部表示。"r"
: 仅返回 R。
- 返回值:
如果
mode
不是"r"
,则返回一个元组(Q, R)
,否则返回一个数组R
,其中Q
是形状为(..., M, K)
(如果mode
是"reduced"
) 或(..., M, M)
(如果mode
是"complete"
) 的正交矩阵。R
是形状为(..., M, N)
(如果mode
是"r"
或"complete"
) 或(..., K, N)
(如果mode
是"reduced"
) 的上三角矩阵。
其中
K = min(M, N)
。- 返回类型:
Array | QRResult
另请参阅
jax.scipy.linalg.qr()
: SciPy 风格的 QR 分解 APIjax.lax.linalg.qr()
: XLA 风格的 QR 分解 API
示例
计算矩阵的 QR 分解
>>> a = jnp.array([[1., 2., 3., 4.], ... [5., 4., 2., 1.], ... [6., 3., 1., 5.]]) >>> Q, R = jnp.linalg.qr(a) >>> Q Array([[-0.12700021, -0.7581426 , -0.6396022 ], [-0.63500065, -0.43322435, 0.63960224], [-0.7620008 , 0.48737738, -0.42640156]], dtype=float32) >>> R Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ], [ 0. , -1.7870499, -2.6534991, -1.028908 ], [ 0. , 0. , -1.0660033, -4.050814 ]], dtype=float32)
检查
Q
是否是正交矩阵>>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5) Array(True, dtype=bool)
重构输入
>>> jnp.allclose(Q @ R, a) Array(True, dtype=bool)