jax.lax.custom_linear_solve#
- jax.lax.custom_linear_solve(matvec, b, solve, transpose_solve=None, symmetric=False, has_aux=False)[source]#
使用隐式定义的梯度执行无矩阵线性求解。
此函数允许通过解决方案处的隐式微分直接覆盖或定义线性求解的梯度,而不是通过求解操作进行微分。这有时可能快得多或数值上更稳定,或者通过求解操作进行微分甚至可能没有实现(例如,如果
solve
使用lax.while_loop
)。必需的不变式
x = solve(matvec, b) # solve the linear equation assert matvec(x) == b # not checked
- 参数:
matvec – 线性函数,需要求逆。必须可微分。
b – 方程的常数右侧。可以是任何嵌套的数组结构。
solve – 更高级的函数,用于求解线性方程的解,即,
solve(matvec, x) == x
对于所有与b
形式相同的x
。此函数不需要可微分。transpose_solve – 用于求解转置线性方程的更高级函数,即,
transpose_solve(vecmat, x) == x
,其中vecmat
是线性映射matvec
的转置(使用自动微分自动计算)。在反向模式自动微分中需要,除非symmetric=True
,在这种情况下,solve
提供默认值。symmetric – 布尔值,指示是否可以安全地假设线性映射对应于对称矩阵,即,
matvec == vecmat
。has_aux – 布尔值,指示
solve
和transpose_solve
函数是否返回辅助数据,例如求解器诊断,作为第二个参数。
- 返回值:
solve(matvec, b)
的结果,假设解x
满足线性方程matvec(x) == b
,则定义梯度。