jax.scipy.linalg.lu

内容

jax.scipy.linalg.lu#

jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[False] = False, overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array, Array][source]#
jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array]
jax.scipy.linalg.lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array] | tuple[Array, Array, Array]

计算LU分解

JAX 实现的 scipy.linalg.lu().

矩阵 A 的 LU 分解为

\[A = P L U\]

其中 P 是一个置换矩阵,L 是下三角矩阵,U 是上三角矩阵。

参数:
  • a – 形状为 (..., M, N) 的要分解的数组。

  • permute_l – 如果为 True,则置换 L 并返回 (P @ L, U)(默认值:False)

  • overwrite_a – JAX 未使用

  • check_finite – JAX 未使用

返回值:

  • P 是形状为 (..., M, M) 的置换矩阵

  • L 是形状为 (... M, K) 的下三角矩阵

  • U 是形状为 (..., K, N) 的上三角矩阵

其中 K = min(M, N)

返回类型:

如果 permute_l 为 True,则为数组元组 (P @ L, U),否则为 (P, L, U)

另请参阅

示例

3x3 矩阵的 LU 分解

>>> a = jnp.array([[1., 2., 3.],
...                [5., 4., 2.],
...                [3., 2., 1.]])
>>> P, L, U = jax.scipy.linalg.lu(a)

P 是一个置换矩阵:即每一行和每一列都只有一个 1

>>> P
Array([[0., 1., 0.],
       [1., 0., 0.],
       [0., 0., 1.]], dtype=float32)

LU 是下三角和上三角矩阵

>>> with jnp.printoptions(precision=3):
...   print(L)
...   print(U)
[[ 1.     0.     0.   ]
 [ 0.2    1.     0.   ]
 [ 0.6   -0.333  1.   ]]
[[5.    4.    2.   ]
 [0.    1.2   2.6  ]
 [0.    0.    0.667]]

可以通过将这三个矩阵相乘来重建原始矩阵

>>> a_reconstructed = P @ L @ U
>>> jnp.allclose(a, a_reconstructed)
Array(True, dtype=bool)