jax.numpy.linalg.solve#
- jax.numpy.linalg.solve(a, b)[源代码]#
求解线性方程组
numpy.linalg.solve()
的 JAX 实现。对于给定的
a
和b
,求解(批处理的)线性方程组a @ x = b
的x
。- 参数:
a (ArrayLike) – 形状为
(..., N, N)
的数组。b (类数组) – 形状为
(N,)
的数组(对于一维右侧)或(..., N, M)
的数组(对于批量二维右侧)。
- 返回:
一个包含线性解结果的数组。如果
b
的形状为(N,)
,则结果的形状为(..., N)
;否则,结果的形状为(..., N, M)
。- 返回类型:
另请参阅
jax.scipy.linalg.solve()
:用于求解线性系统的 SciPy 风格 API。jax.lax.custom_linear_solve()
:无矩阵线性求解器。
示例
一个简单的 3x3 线性系统
>>> A = jnp.array([[1., 2., 3.], ... [2., 4., 2.], ... [3., 2., 1.]]) >>> b = jnp.array([14., 16., 10.]) >>> x = jnp.linalg.solve(A, b) >>> x Array([1., 2., 3.], dtype=float32)
确认结果确实是系统的解
>>> jnp.allclose(A @ x, b) Array(True, dtype=bool)