jax.scipy.sparse.linalg.cg#
- jax.scipy.sparse.linalg.cg(A, b, x0=None, *, tol=1e-05, atol=0.0, maxiter=None, M=None)[source]#
使用共轭梯度迭代求解
Ax = b
。JAX 的
cg
的数值结果应该与 SciPy 的cg
完全匹配(在数值精度范围内),但要注意接口略有不同:您需要提供线性算子A
作为函数,而不是稀疏矩阵或LinearOperator
。cg
的导数是通过隐式微分实现的,使用另一个cg
求解,而不是通过求解器进行微分。它们只有在两个求解都收敛时才是准确的。- 参数:
A (ndarray, function 或 支持矩阵乘法的对象) – 2D 数组或函数,在被调用时计算线性映射(矩阵-向量乘积)
Ax
,类似于A(x)
或A @ x
。A
必须表示一个埃尔米特正定矩阵,并且必须返回与其参数具有相同结构和形状的数组。b (array 或 tree of arrays) – 线性方程组的右侧,表示单个向量。可以存储为数组或具有任意形状的数组的 Python 容器。
x0 (array 或 tree of arrays) – 解决方案的初始猜测。必须与
b
具有相同的结构。tol (float, optional) – 收敛容差,
norm(residual) <= max(tol*norm(b), atol)
。我们没有实现 SciPy 的“传统”行为,因此除非您显式地将atol
传递给 SciPy 的cg
,否则 JAX 的容差将与 SciPy 不同。atol (float, optional) – 收敛容差,
norm(residual) <= max(tol*norm(b), atol)
。我们没有实现 SciPy 的“传统”行为,因此除非您显式地将atol
传递给 SciPy 的cg
,否则 JAX 的容差将与 SciPy 不同。maxiter (integer) – 最大迭代次数。即使没有达到指定的容差,迭代也会在 maxiter 步后停止。
M (ndarray, function 或 支持矩阵乘法的对象) – A 的预处理器。预处理器应该近似于 A 的逆。有效的预处理会显著提高收敛速度,这意味着达到给定误差容差所需的迭代次数更少。
- 返回值:
x (array 或 tree of arrays) – 收敛的解。与
b
具有相同的结构。info (None) – 收敛信息的占位符。将来,JAX 会像 SciPy 一样报告未收敛时的迭代次数。