jax.numpy.linalg.solve#
- jax.numpy.linalg.solve(a, b)[源代码]#
求解线性方程组
JAX 实现的
numpy.linalg.solve()
。此函数求解 (批量的) 线性方程组
a @ x = b
,对于给定的a
和b
,求解x
。- 参数:
a (ArrayLike) – 形状为
(..., N, N)
的数组。b (ArrayLike) – 形状为
(N,)
(对于一维右侧向量) 或(..., N, M)
(对于批量二维右侧向量) 的数组。
- 返回:
一个包含线性求解结果的数组。如果
b
的形状为(N,)
,则结果的形状为(..., N)
;否则,结果的形状为(..., N, M)
。- 返回类型:
另请参阅
jax.scipy.linalg.solve()
: 用于求解线性系统的 SciPy 风格 API。jax.lax.custom_linear_solve()
: 无矩阵线性求解器。
示例
一个简单的 3x3 线性系统
>>> A = jnp.array([[1., 2., 3.], ... [2., 4., 2.], ... [3., 2., 1.]]) >>> b = jnp.array([14., 16., 10.]) >>> x = jnp.linalg.solve(A, b) >>> x Array([1., 2., 3.], dtype=float32)
确认结果确实求解了该系统
>>> jnp.allclose(A @ x, b) Array(True, dtype=bool)