jax.scipy.linalg.lu#

jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[False] = False, overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array, Array][源代码]#
jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array]
jax.scipy.linalg.lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array] | tuple[Array, Array, Array]

计算 LU 分解

scipy.linalg.lu() 的 JAX 实现。

矩阵 A 的 LU 分解为

\[A = P L U\]

其中 P 是置换矩阵,L 是下三角矩阵,U 是上三角矩阵。

参数:
  • a – 要分解的形状为 (..., M, N) 的数组。

  • permute_l – 如果为 True,则置换 L 并返回 (P @ L, U) (默认值:False)

  • overwrite_a – JAX 不使用。

  • check_finite – JAX 不使用。

返回:

  • P 是形状为 (..., M, M) 的置换矩阵。

  • L 是形状为 (... M, K) 的下三角矩阵。

  • U 是形状为 (..., K, N) 的上三角矩阵。

其中 K = min(M, N)

返回类型:

如果 permute_l 为 True,则返回数组的元组 (P @ L, U),否则返回 (P, L, U)

另请参阅

示例

一个 3x3 矩阵的 LU 分解

>>> a = jnp.array([[1., 2., 3.],
...                [5., 4., 2.],
...                [3., 2., 1.]])
>>> P, L, U = jax.scipy.linalg.lu(a)

P 是一个置换矩阵:即每一行和每一列都有一个单独的 1

>>> P
Array([[0., 1., 0.],
       [1., 0., 0.],
       [0., 0., 1.]], dtype=float32)

LU 分别是下三角矩阵和上三角矩阵

>>> with jnp.printoptions(precision=3):
...   print(L)
...   print(U)
[[ 1.     0.     0.   ]
 [ 0.2    1.     0.   ]
 [ 0.6   -0.333  1.   ]]
[[5.    4.    2.   ]
 [0.    1.2   2.6  ]
 [0.    0.    0.667]]

原始矩阵可以通过将这三者相乘来重建

>>> a_reconstructed = P @ L @ U
>>> jnp.allclose(a, a_reconstructed)
Array(True, dtype=bool)