jax.lax.custom_root#

jax.lax.custom_root(f, initial_guess, solve, tangent_solve, has_aux=False)[源代码]#

可微地求解函数的根。

这是一个底层例程,主要用于 JAX 内部。`custom_root()` 的梯度是根据所提供的函数 `f` 中闭包变量定义的,通过隐函数定理:https://en.wikipedia.org/wiki/Implicit_function_theorem

参数:
  • f – 要寻找根的函数。应接受一个参数,返回一个与输入结构相同的数组树。

  • initial_guess – f 的零点的初始猜测。

  • solve

    求解 f 的根的函数。应接受两个位置参数,f 和 initial_guess,并返回一个与 initial_guess 结构相同的解,使得 func(solution) = 0。换句话说,假设以下为真(但不检查)

    solution = solve(f, initial_guess)
    error = f(solution)
    assert all(error == 0)
    

  • tangent_solve

    求解切线系统的函数。应接受两个位置参数,一个线性函数 `g`(函数 `f` 在其根处线性化)和一个与 initial_guess 结构相同的数组树 `y`,并返回一个解 `x`,使得 `g(x)=y`

    • 对于标量 `y`,使用 `lambda g, y: y / g(1.0)`。

    • 对于向量 `y`,如果 `y` 的维度不太大,可以使用带雅可比矩阵的线性求解:`lambda g, y: np.linalg.solve(jacobian(g)(y), y)`。

  • has_aux – 布尔值,指示 `solve` 函数是否返回辅助数据,如求解器诊断信息作为第二个参数。

返回值:

调用 solve(f, initial_guess) 的结果,梯度通过隐式微分定义,假设 `f(solve(f, initial_guess)) == 0`。