jax.scipy.linalg.solve_triangular#
- jax.scipy.linalg.solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False, overwrite_b=False, debug=None, check_finite=True)[源代码]#
求解三角线性方程组。
scipy.linalg.solve_triangular()
的 JAX 实现。对于给定的三角矩阵
a
和向量或矩阵b
,求解(批量)线性方程组a @ x = b
中的x
。- 参数:
a (ArrayLike) – 形状为
(..., N, N)
的数组。 仅会访问数组的一部分,具体取决于lower
和unit_diagonal
参数。b (ArrayLike) – 形状为
(..., N)
或(..., N, M)
的数组。lower (bool) – 如果为 True,则仅使用输入的下三角部分;如果为 False(默认值),则仅使用上三角部分。
unit_diagonal (bool) – 如果为 True,则忽略
a
的对角线元素,并假设它们为1
(默认值:False)。指定可以假设
a
的哪些属性。 可选值如下:0
或'N'
: 求解 \(Ax=b\)1
或'T'
: 求解 \(A^Tx=b\)2
或'C'
: 求解 \(A^Hx=b\)
overwrite_b (bool) – JAX 未使用。
debug (Any | None) – JAX 未使用。
check_finite (bool) – JAX 未使用。
- 返回:
包含线性方程组解的,与
b
形状相同的数组。- 返回类型:
另请参阅
jax.scipy.linalg.solve()
:求解一般线性方程组。示例
一个简单的 3x3 三角线性方程组
>>> A = jnp.array([[1., 2., 3.], ... [0., 3., 2.], ... [0., 0., 5.]]) >>> b = jnp.array([10., 8., 5.]) >>> x = jax.scipy.linalg.solve_triangular(A, b) >>> x Array([3., 2., 1.], dtype=float32)
确认结果可以求解该方程组
>>> jnp.allclose(A @ x, b) Array(True, dtype=bool)
计算转置问题
>>> x = jax.scipy.linalg.solve_triangular(A, b, trans='T') >>> x Array([10. , -4. , -3.4], dtype=float32)
确认结果可以求解该方程组
>>> jnp.allclose(A.T @ x, b) Array(True, dtype=bool)