jax.numpy.linalg.svd#

jax.numpy.linalg.svd(a, full_matrices=True, compute_uv=True, hermitian=False, subset_by_index=None)[源代码]#

计算奇异值分解。

JAX 实现的 numpy.linalg.svd(),基于 jax.lax.linalg.svd() 实现。

矩阵 A 的 SVD 由下式给出

\[A = U\Sigma V^H\]
  • \(U\) 包含左奇异向量,并满足 \(U^HU=I\)

  • \(V\) 包含右奇异向量,并满足 \(V^HV=I\)

  • \(\Sigma\) 是一个奇异值的对角矩阵。

参数:
  • a (ArrayLike) – 输入数组,形状为 (..., N, M)

  • full_matrices (bool) – 如果为 True (默认),则计算完整矩阵;即 uvh 的形状分别为 (..., N, N)(..., M, M)。如果为 False,则形状分别为 (..., N, K)(..., K, M),其中 K = min(N, M)

  • compute_uv (bool) – 如果为 True (默认),则返回完整 SVD (u, s, vh)。如果为 False,则仅返回奇异值 s

  • hermitian (bool) – 如果为 True,则假设矩阵是厄米矩阵,这可以实现更高效的实现(默认值=False)

  • subset_by_index (tuple[int, int] | None) – (仅限 TPU)可选的 2 元组 [start, end],表示要计算的奇异值的索引范围。例如,如果 [n-2, n],则 svd 计算两个最大的奇异值及其奇异向量。仅与 full_matrices=False 兼容。

返回值:

如果 compute_uv 为 True,则返回数组元组 (u, s, vh),否则返回数组 s

  • u: 如果 full_matrices 为 True,则左奇异向量的形状为 (..., N, N),否则为 (..., N, K)

  • s: 奇异值的形状为 (..., K)

  • vh: 如果 full_matrices 为 True,则共轭转置右奇异向量的形状为 (..., M, M),否则为 (..., K, M)

其中 K = min(N, M)

返回类型:

Array | SVDResult

另请参阅

示例

考虑一个小型实值数组的 SVD

>>> x = jnp.array([[1., 2., 3.],
...                [6., 5., 4.]])
>>> u, s, vt = jnp.linalg.svd(x, full_matrices=False)
>>> s  
Array([9.361919 , 1.8315067], dtype=float32)

奇异向量位于 uv = vt.T 的列中。这些向量是正交的,可以通过将矩阵乘积与单位矩阵进行比较来证明

>>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)
>>> v = vt.T
>>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)

给定 SVD,可以通过矩阵乘法重建 x

>>> x_reconstructed = u @ jnp.diag(s) @ vt
>>> jnp.allclose(x_reconstructed, x)
Array(True, dtype=bool)