jax.lax.custom_root#
- jax.lax.custom_root(f, initial_guess, solve, tangent_solve, has_aux=False)[source]#
可微分地求解函数的根。
这是一个低级例程,主要用于 JAX 的内部使用。custom_root() 的梯度是通过隐函数定理定义的,相对于来自提供的函数
f
的闭包变量:https://en.wikipedia.org/wiki/Implicit_function_theorem- 参数:
f – 用于查找根的函数。应该接受一个参数,返回一个数组树,其结构与其输入相同。
initial_guess – 对 f 的零点的初始猜测。
solve –
用于求解 f 根的函数。应该接受两个位置参数,f 和 initial_guess,并返回一个与 initial_guess 结构相同的解,使得 func(solution) = 0。换句话说,假设以下内容为真(但不会检查)
solution = solve(f, initial_guess) error = f(solution) assert all(error == 0)
tangent_solve –
用于求解切线系统的函数。应该接受两个位置参数,一个线性函数
g
(在根处线性化的函数f
)和一个与 initial_guess 结构相同的数组树y
,并返回一个解x
,使得g(x)=y
对于标量
y
,使用lambda g, y: y / g(1.0)
。对于向量
y
,如果y
的维度不太大,可以使用雅可比矩阵进行线性求解:lambda g, y: np.linalg.solve(jacobian(g)(y), y)
。
has_aux – bool 类型,指示
solve
函数是否返回辅助数据,例如求解器诊断信息,作为第二个参数。
- 返回值:
调用 solve(f, initial_guess) 的结果,其梯度通过隐式微分定义,假设
f(solve(f, initial_guess)) == 0
。