jax.scipy.linalg.lu#
- jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[False] = False, overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array, Array] [source]#
- jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array]
- jax.scipy.linalg.lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array] | tuple[Array, Array, Array]
计算LU分解
JAX 实现的
scipy.linalg.lu()
.矩阵 A 的 LU 分解为
\[A = P L U\]其中 P 是一个置换矩阵,L 是下三角矩阵,U 是上三角矩阵。
- 参数:
a – 形状为
(..., M, N)
的要分解的数组。permute_l – 如果为 True,则置换
L
并返回(P @ L, U)
(默认值:False)overwrite_a – JAX 未使用
check_finite – JAX 未使用
- 返回值:
P
是形状为(..., M, M)
的置换矩阵L
是形状为(... M, K)
的下三角矩阵U
是形状为(..., K, N)
的上三角矩阵
其中
K = min(M, N)
- 返回类型:
如果
permute_l
为 True,则为数组元组(P @ L, U)
,否则为(P, L, U)
另请参阅
jax.numpy.linalg.lu()
:NumPy 样式的 LU 分解 API。jax.lax.linalg.lu()
:XLA 样式的 LU 分解 API。jax.scipy.linalg.lu_solve()
:基于 LU 的线性求解器。
示例
3x3 矩阵的 LU 分解
>>> a = jnp.array([[1., 2., 3.], ... [5., 4., 2.], ... [3., 2., 1.]]) >>> P, L, U = jax.scipy.linalg.lu(a)
P
是一个置换矩阵:即每一行和每一列都只有一个1
>>> P Array([[0., 1., 0.], [1., 0., 0.], [0., 0., 1.]], dtype=float32)
L
和U
是下三角和上三角矩阵>>> with jnp.printoptions(precision=3): ... print(L) ... print(U) [[ 1. 0. 0. ] [ 0.2 1. 0. ] [ 0.6 -0.333 1. ]] [[5. 4. 2. ] [0. 1.2 2.6 ] [0. 0. 0.667]]
可以通过将这三个矩阵相乘来重建原始矩阵
>>> a_reconstructed = P @ L @ U >>> jnp.allclose(a, a_reconstructed) Array(True, dtype=bool)