jax.scipy.linalg.eigh#
- jax.scipy.linalg.eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, eigvals_only: Literal[False] = False, overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) tuple[Array, Array] [源代码]#
- jax.scipy.linalg.eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, *, eigvals_only: Literal[True], overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) Array
- jax.scipy.linalg.eigh(a: ArrayLike, b: ArrayLike | None, lower: bool, eigvals_only: Literal[True], overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) Array
- jax.scipy.linalg.eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, eigvals_only: bool = False, overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) Array | tuple[Array, Array]
计算 Hermitian 矩阵的特征值和特征向量
scipy.linalg.eigh()
的 JAX 实现。- 参数:
a – Hermitian 输入数组,形状为
(..., N, N)
b – 可选的 Hermitian 输入,形状为
(..., N, N)
。 如果指定,则计算广义特征值问题。lower – 如果为 True(默认),则仅访问输入矩阵的下半部分。否则,仅访问上半部分。
eigvals_only – 如果为 True,则仅计算特征值。如果为 False(默认),则同时计算特征值和特征向量。
type –
如果指定了
b
,则type
指定要计算的广义特征值问题的类型。 将(λ, v)
表示为特征值、特征向量对type = 1
求解a @ v = λ * b @ v
(默认)type = 2
求解a @ b @ v = λ * v
type = 3
求解b @ a @ v = λ * v
eigvals – 一个
(low, high)
元组,指定要计算的特征值。overwrite_a – JAX 未使用。
overwrite_b – JAX 未使用。
turbo – JAX 未使用。
check_finite – JAX 未使用。
- 返回:
如果
eigvals_only
为 False,则返回数组元组(eigvals, eigvecs)
,否则返回数组eigvals
。eigvals
: 包含特征值的数组,形状为(..., N)
。eigvecs
: 包含特征向量的数组,形状为(..., N, N)
。
另请参阅
jax.numpy.linalg.eigh()
: NumPy 风格的 eigh API。jax.lax.linalg.eigh()
: XLA 风格的 eigh API。jax.numpy.linalg.eig()
: 非 Hermitian 特征值问题。jax.scipy.linalg.eigh_tridiagonal()
: 三对角特征值问题。
示例
计算一个简单的 2x2 矩阵的标准特征值分解
>>> a = jnp.array([[2., 1.], ... [1., 2.]]) >>> eigvals, eigvecs = jax.scipy.linalg.eigh(a) >>> eigvals Array([1., 3.], dtype=float32) >>> eigvecs Array([[-0.70710677, 0.70710677], [ 0.70710677, 0.70710677]], dtype=float32)
特征向量是标准正交的
>>> jnp.allclose(eigvecs.T @ eigvecs, jnp.eye(2), atol=1E-5) Array(True, dtype=bool)
解满足特征值问题
>>> jnp.allclose(a @ eigvecs, eigvecs @ jnp.diag(eigvals)) Array(True, dtype=bool)