jax.numpy.linalg.solve#
- jax.numpy.linalg.solve(a, b)[source]#
求解线性方程组
JAX 实现
numpy.linalg.solve()
.这将为给定的
a
和b
求解(批处理)线性方程组a @ x = b
的x
。- 参数:
a (ArrayLike) – 形状为
(..., N, N)
的数组。b (ArrayLike) – 形状为
(N,)
(对于一维右手边)或(..., N, M)
(对于批处理二维右手边)的数组。
- 返回值:
包含线性求解结果的数组。如果
b
的形状为(N,)
,则结果的形状为(..., N)
,否则形状为(..., N, M)
。- 返回类型:
另请参阅
jax.scipy.linalg.solve()
: 用于求解线性方程组的 SciPy 风格 API。jax.lax.custom_linear_solve()
: 无矩阵线性求解器。
示例
一个简单的 3x3 线性方程组
>>> A = jnp.array([[1., 2., 3.], ... [2., 4., 2.], ... [3., 2., 1.]]) >>> b = jnp.array([14., 16., 10.]) >>> x = jnp.linalg.solve(A, b) >>> x Array([1., 2., 3.], dtype=float32)
确认结果解决了该方程组
>>> jnp.allclose(A @ x, b) Array(True, dtype=bool)