jax.scipy.linalg.lu_solve#
- jax.scipy.linalg.lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True)[source]#
使用 LU 分解求解线性系统
JAX 实现的
scipy.linalg.lu_solve()
。使用jax.scipy.linalg.lu_factor()
的输出。- 参数:
lu_and_piv (元组[数组, 数组类]) –
(lu, piv)
, 为lu_factor()
的输出。lu
是形状为(..., M, N)
的数组,包含其下三角中的L
和其上三角中的U
。piv
是形状为(..., K)
的数组,其中K = min(M, N)
,它编码了枢轴。b (数组类) – 线性系统的右手边。必须具有形状
(..., M)
trans (整数) –
要解决的系统类型。选项是
0
: \(A x = b\)1
: \(A^Tx = b\)2
: \(A^Hx = b\)
overwrite_b (布尔值) – JAX 未使用
check_finite (布尔值) – JAX 未使用
- 返回值:
形状为
(..., N)
的数组,表示线性系统的解。- 返回类型:
示例
通过 LU 分解求解小型线性系统
>>> a = jnp.array([[2., 1.], ... [1., 2.]])
通过
lu_factor()
计算 LU 分解,并通过lu_solve()
用它来求解线性方程。>>> b = jnp.array([3., 4.]) >>> lufac = jax.scipy.linalg.lu_factor(a) >>> y = jax.scipy.linalg.lu_solve(lufac, b) >>> y Array([0.6666666, 1.6666667], dtype=float32)
检查结果是否一致
>>> jnp.allclose(a @ y, b) Array(True, dtype=bool)