jax.scipy.linalg.lu_solve

内容

jax.scipy.linalg.lu_solve#

jax.scipy.linalg.lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True)[source]#

使用 LU 分解求解线性系统

JAX 实现的 scipy.linalg.lu_solve()。使用 jax.scipy.linalg.lu_factor() 的输出。

参数:
  • lu_and_piv (元组[数组, 数组类]) – (lu, piv), 为 lu_factor() 的输出。 lu 是形状为 (..., M, N) 的数组,包含其下三角中的 L 和其上三角中的 Upiv 是形状为 (..., K) 的数组,其中 K = min(M, N),它编码了枢轴。

  • b (数组类) – 线性系统的右手边。必须具有形状 (..., M)

  • trans (整数) –

    要解决的系统类型。选项是

    • 0: \(A x = b\)

    • 1: \(A^Tx = b\)

    • 2: \(A^Hx = b\)

  • overwrite_b (布尔值) – JAX 未使用

  • check_finite (布尔值) – JAX 未使用

返回值:

形状为 (..., N) 的数组,表示线性系统的解。

返回类型:

数组

示例

通过 LU 分解求解小型线性系统

>>> a = jnp.array([[2., 1.],
...                [1., 2.]])

通过 lu_factor() 计算 LU 分解,并通过 lu_solve() 用它来求解线性方程。

>>> b = jnp.array([3., 4.])
>>> lufac = jax.scipy.linalg.lu_factor(a)
>>> y = jax.scipy.linalg.lu_solve(lufac, b)
>>> y
Array([0.6666666, 1.6666667], dtype=float32)

检查结果是否一致

>>> jnp.allclose(a @ y, b)
Array(True, dtype=bool)