jax.scipy.linalg.eigh_tridiagonal

jax.scipy.linalg.eigh_tridiagonal#

jax.scipy.linalg.eigh_tridiagonal(d, e, *, eigvals_only=False, select='a', select_range=None, tol=None)[source]#

求解对称实三对角矩阵的特征值问题

JAX 实现 scipy.linalg.eigh_tridiagonal().

参数:
  • d (ArrayLike) – 形状为 (N,) 的实值数组,指定对角元素。

  • e (ArrayLike) – 形状为 (N - 1,) 的实值数组,指定非对角元素。

  • eigvals_only (bool) – 如果为 True,则仅返回特征值(默认值为 False)。特征向量的计算尚未实现,因此 eigvals_only 必须设置为 True。

  • select (str) –

    指定要计算的特征值。支持的值为

    • 'a': 所有特征值

    • 'i': 索引为 select_range[0] <= i <= select_range[1] 的特征值

    JAX 目前未实现 select = 'v'

  • select_range (tuple[float, float] | None) – 当 select='i' 时使用的值范围。

  • tol (float | None) – 求解特征值时使用的绝对容差。

返回值:

形状为 (N,) 的特征值数组。

返回类型:

数组

参见

jax.scipy.linalg.eigh(): 一般 Hermitian 特征值求解器

示例

>>> d = jnp.array([1., 2., 3., 4.])
>>> e = jnp.array([1., 1., 1.])
>>> eigvals = jax.scipy.linalg.eigh_tridiagonal(d, e, eigvals_only=True)
>>> eigvals
Array([0.2547188, 1.8227171, 3.1772828, 4.745281 ], dtype=float32)

为了比较,我们可以构造完整的矩阵并使用 eigh() 计算相同的结果

>>> A = jnp.diag(d) + jnp.diag(e, 1) + jnp.diag(e, -1)
>>> A
Array([[1., 1., 0., 0.],
       [1., 2., 1., 0.],
       [0., 1., 3., 1.],
       [0., 0., 1., 4.]], dtype=float32)
>>> eigvals_full = jax.scipy.linalg.eigh(A, eigvals_only=True)
>>> jnp.allclose(eigvals, eigvals_full)
Array(True, dtype=bool)