jax.scipy.sparse.linalg.cg#

jax.scipy.sparse.linalg.cg(A, b, x0=None, *, tol=1e-05, atol=0.0, maxiter=None, M=None)[源代码]#

使用共轭梯度迭代法求解 Ax = b

JAX 的 cg 的数值计算应该与 SciPy 的 cg 完全匹配(直到数值精度),但请注意接口略有不同:您需要将线性算子 A 作为函数提供,而不是稀疏矩阵或 LinearOperator

cg 的导数通过隐式微分和另一个 cg 求解来实现,而不是通过微分通过求解器。只有当两个求解都收敛时,它们才是准确的。

参数:
  • A (ndarray函数与矩阵乘法兼容的对象) – 2D 数组或函数,当像 A(x)A @ x 这样调用时,计算线性映射(矩阵-向量积) AxA 必须表示一个 Hermitian、正定矩阵,并且必须返回与其参数具有相同结构和形状的数组。

  • b (数组数组的树) – 表示单个向量的线性系统的右侧。 可以存储为具有任何形状的数组或数组的 Python 容器。

  • x0 (数组数组的树) – 解的起始猜测。 必须与 b 具有相同的结构。

  • tol (float, 可选) – 收敛的容差,norm(residual) <= max(tol*norm(b), atol)。我们没有实现 SciPy 的“遗留”行为,因此 JAX 的容差将与 SciPy 不同,除非您显式地将 atol 传递给 SciPy 的 cg

  • atol (float, 可选) – 收敛的容差,norm(residual) <= max(tol*norm(b), atol)。我们没有实现 SciPy 的“遗留”行为,因此 JAX 的容差将与 SciPy 不同,除非您显式地将 atol 传递给 SciPy 的 cg

  • maxiter (integer) – 最大迭代次数。即使没有达到指定的容差,迭代也会在 maxiter 步后停止。

  • M (ndarray函数与矩阵乘法兼容的对象) – A 的预处理器。预处理器应近似 A 的逆矩阵。有效的预处理可以显著提高收敛速度,这意味着达到给定的误差容差所需的迭代次数更少。

返回:

  • x (数组或数组的树) – 收敛的解。与 b 具有相同的结构。

  • info (None) – 收敛信息的占位符。将来,当未达到收敛时,JAX 将报告迭代次数,如 SciPy。