jax.closure_convert#
- jax.closure_convert(fun, *example_args)[source]#
闭包转换实用程序,用于与高阶自定义导数一起使用。
要定义自定义导数,例如使用
jax.custom_vjp(f)
,目标函数f
必须以形式参数的形式接受所有参与微分的数值。如果f
是一个高阶函数,因为它接受一个 Python 函数g
作为参数,那么存储在g
的闭包中的数值将无法被自定义导数规则看到,并且涉及这些数值的 AD 尝试将会失败。解决此问题的一种方法是通过提取这些数值来转换闭包,并将它们作为显式形式参数传递到自定义导数边界。此实用程序执行该转换。更准确地说,它对专门用于example_args
中给出的参数类型的函数fun
进行闭包转换。当我们在这里提到
fun
的“闭包中的数值”时,我们并不是指在定义fun
时 Python 直接捕获的数值(例如,如果存在该属性,则为fun.__closure__
中的 Python 对象)。相反,我们指的是在fun
在example_args
上执行过程中遇到的确定其输出的数值。例如,这可能包括在 Python 闭包中传递捕获的数组,即在fun
调用的函数的 Python 闭包中,它们调用的函数的闭包中,等等。函数
fun
必须是一个纯函数。示例用法
def minimize(objective_fn, x0): converted_fn, aux_args = closure_convert(objective_fn, x0) return _minimize(converted_fn, x0, *aux_args) @partial(custom_vjp, nondiff_argnums=(0,)) def _minimize(objective_fn, x0, *args): z = objective_fn(x0, *args) # ... find minimizer x_opt ... return x_opt def fwd(objective_fn, x0, *args): y = _minimize(objective_fn, x0, *args) return y, (y, args) def rev(objective_fn, res, g): y, args = res y_bar = g # ... custom reverse-mode AD ... return x0_bar, *args_bars _minimize.defvjp(fwd, rev)