jax.scipy.linalg.lu#
- jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[False] = False, overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array, Array] [源代码]#
- jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array]
- jax.scipy.linalg.lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array] | tuple[Array, Array, Array]
计算 LU 分解
scipy.linalg.lu()
的 JAX 实现。矩阵 A 的 LU 分解为
\[A = P L U\]其中 P 是置换矩阵,L 是下三角矩阵,U 是上三角矩阵。
- 参数:
a – 要分解的形状为
(..., M, N)
的数组。permute_l – 如果为 True,则置换
L
并返回(P @ L, U)
(默认值:False)overwrite_a – JAX 不使用。
check_finite – JAX 不使用。
- 返回:
P
是形状为(..., M, M)
的置换矩阵。L
是形状为(... M, K)
的下三角矩阵。U
是形状为(..., K, N)
的上三角矩阵。
其中
K = min(M, N)
- 返回类型:
如果
permute_l
为 True,则返回数组的元组(P @ L, U)
,否则返回(P, L, U)
。
另请参阅
jax.numpy.linalg.lu()
:用于 LU 分解的 NumPy 风格 API。jax.lax.linalg.lu()
:用于 LU 分解的 XLA 风格 API。jax.scipy.linalg.lu_solve()
:基于 LU 的线性求解器。
示例
一个 3x3 矩阵的 LU 分解
>>> a = jnp.array([[1., 2., 3.], ... [5., 4., 2.], ... [3., 2., 1.]]) >>> P, L, U = jax.scipy.linalg.lu(a)
P
是一个置换矩阵:即每一行和每一列都有一个单独的1
>>> P Array([[0., 1., 0.], [1., 0., 0.], [0., 0., 1.]], dtype=float32)
L
和U
分别是下三角矩阵和上三角矩阵>>> with jnp.printoptions(precision=3): ... print(L) ... print(U) [[ 1. 0. 0. ] [ 0.2 1. 0. ] [ 0.6 -0.333 1. ]] [[5. 4. 2. ] [0. 1.2 2.6 ] [0. 0. 0.667]]
原始矩阵可以通过将这三者相乘来重建
>>> a_reconstructed = P @ L @ U >>> jnp.allclose(a, a_reconstructed) Array(True, dtype=bool)