jax.numpy.linalg.svd#
- jax.numpy.linalg.svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[True], hermitian: bool = False, subset_by_index: tuple[int, int] | None = None) SVDResult [source]#
- jax.numpy.linalg.svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[True], hermitian: bool = False, subset_by_index: tuple[int, int] | None = None) SVDResult
- jax.numpy.linalg.svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False], hermitian: bool = False, subset_by_index: tuple[int, int] | None = None) Array
- jax.numpy.linalg.svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False], hermitian: bool = False, subset_by_index: tuple[int, int] | None = None) Array
- jax.numpy.linalg.svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, hermitian: bool = False, subset_by_index: tuple[int, int] | None = None) Array | SVDResult
计算奇异值分解。
JAX 实现的
numpy.linalg.svd()
,以jax.lax.linalg.svd()
为基础。矩阵 A 的 SVD 表示为
\[A = U\Sigma V^H\]\(U\) 包含左奇异向量,满足 \(U^HU=I\)
\(V\) 包含右奇异向量,满足 \(V^HV=I\)
\(\Sigma\) 是一个奇异值的对角矩阵。
- 参数:
a – 输入数组,形状为
(..., N, M)
full_matrices – 如果为 True(默认),则计算完整的矩阵;即
u
和vh
的形状为(..., N, N)
和(..., M, M)
。如果为 False,则形状为(..., N, K)
和(..., K, M)
,其中K = min(N, M)
。compute_uv – 如果为 True(默认),则返回完整的 SVD
(u, s, vh)
。如果为 False,则仅返回奇异值s
。hermitian – 如果为 True,则假设矩阵为厄米矩阵,这允许更有效的实现(默认=False)
subset_by_index –(仅限 TPU)可选的 2 元组 [start, end],指示要计算的奇异值的索引范围。例如,如果
[n-2, n]
则svd
计算两个最大奇异值及其奇异向量。仅与full_matrices=False
兼容。
- 返回值:
如果
compute_uv
为 True,则为数组(u, s, vh)
的元组,否则为数组s
。u
: 左奇异向量,形状为(..., N, N)
(如果full_matrices
为 True)或(..., N, K)
(否则)。s
: 奇异值,形状为(..., K)
vh
: 右奇异向量的共轭转置,形状为(..., M, M)
(如果full_matrices
为 True)或(..., K, M)
(否则)。
其中
K = min(N, M)
。
另请参阅
jax.scipy.linalg.svd()
: SciPy 风格的 SVD APIjax.lax.linalg.svd()
: XLA 风格的 SVD API
示例
考虑一个小实值数组的 SVD
>>> x = jnp.array([[1., 2., 3.], ... [6., 5., 4.]]) >>> u, s, vt = jnp.linalg.svd(x, full_matrices=False) >>> s Array([9.361919 , 1.8315067], dtype=float32)
奇异向量位于
u
和v = vt.T
的列中。这些向量是正交的,可以通过将矩阵乘积与单位矩阵进行比较来证明>>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5) Array(True, dtype=bool) >>> v = vt.T >>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5) Array(True, dtype=bool)
给定 SVD,
x
可以通过矩阵乘法重建>>> x_reconstructed = u @ jnp.diag(s) @ vt >>> jnp.allclose(x_reconstructed, x) Array(True, dtype=bool)