jax.scipy.linalg.cho_factor#

jax.scipy.linalg.cho_factor(a, lower=False, overwrite_a=False, check_finite=True)[源代码]#

基于 Cholesky 分解的线性方程求解的因式分解

scipy.linalg.cho_factor() 的 JAX 实现。此函数返回适用于 jax.scipy.linalg.cho_solve() 的结果。对于直接的 Cholesky 分解,建议使用 jax.scipy.linalg.cholesky()

参数:
  • a (ArrayLike) – 输入数组,表示(批量的)正定 Hermitian 矩阵。必须具有形状 (..., N, N)

  • lower (bool) – 如果为 True,则计算下三角 Cholesky 分解(默认值:False)。

  • overwrite_a (bool) – JAX 未使用

  • check_finite (bool) – JAX 未使用

返回:

c 是一个形状为 (..., N, N) 的数组,表示输入的下三角或上三角 Cholesky 分解;lower 是一个布尔值,用于指定这是下三角分解还是上三角分解。

返回类型:

(c, lower)

示例

一个小的实 Hermitian 正定矩阵

>>> x = jnp.array([[2., 1.],
...                [1., 2.]])

通过 cho_factor() 计算 Cholesky 因式分解,并使用它通过 cho_solve() 求解线性方程。

>>> b = jnp.array([3., 4.])
>>> cfac = jax.scipy.linalg.cho_factor(x)
>>> y = jax.scipy.linalg.cho_solve(cfac, b)
>>> y
Array([0.6666666, 1.6666666], dtype=float32)

检查结果是否一致

>>> jnp.allclose(x @ y, b)
Array(True, dtype=bool)