jax.scipy.linalg.solve_triangular#

jax.scipy.linalg.solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False, overwrite_b=False, debug=None, check_finite=True)[源代码]#

求解三角线性方程组。

scipy.linalg.solve_triangular() 的 JAX 实现。

对于给定的三角矩阵 a 和向量或矩阵 b,求解(批量)线性方程组 a @ x = b 中的 x

参数
  • a (ArrayLike) – 形状为 (..., N, N) 的数组。 仅会访问数组的一部分,具体取决于 lowerunit_diagonal 参数。

  • b (ArrayLike) – 形状为 (..., N)(..., N, M) 的数组。

  • lower (bool) – 如果为 True,则仅使用输入的下三角部分;如果为 False(默认值),则仅使用上三角部分。

  • unit_diagonal (bool) – 如果为 True,则忽略 a 的对角线元素,并假设它们为 1(默认值:False)。

  • trans (int | str) –

    指定可以假设 a 的哪些属性。 可选值如下:

    • 0'N': 求解 \(Ax=b\)

    • 1'T': 求解 \(A^Tx=b\)

    • 2'C': 求解 \(A^Hx=b\)

  • overwrite_b (bool) – JAX 未使用。

  • debug (Any | None) – JAX 未使用。

  • check_finite (bool) – JAX 未使用。

返回

包含线性方程组解的,与 b 形状相同的数组。

返回类型

Array

另请参阅

jax.scipy.linalg.solve():求解一般线性方程组。

示例

一个简单的 3x3 三角线性方程组

>>> A = jnp.array([[1., 2., 3.],
...                [0., 3., 2.],
...                [0., 0., 5.]])
>>> b = jnp.array([10., 8., 5.])
>>> x = jax.scipy.linalg.solve_triangular(A, b)
>>> x
Array([3., 2., 1.], dtype=float32)

确认结果可以求解该方程组

>>> jnp.allclose(A @ x, b)
Array(True, dtype=bool)

计算转置问题

>>> x = jax.scipy.linalg.solve_triangular(A, b, trans='T')
>>> x
Array([10. , -4. , -3.4], dtype=float32)

确认结果可以求解该方程组

>>> jnp.allclose(A.T @ x, b)
Array(True, dtype=bool)