jax.closure_convert

jax.closure_convert#

jax.closure_convert(fun, *example_args)[source]#

闭包转换实用程序,用于与高阶自定义导数一起使用。

要定义自定义导数,例如使用 jax.custom_vjp(f),目标函数 f 必须以形式参数的形式接受所有参与微分的数值。如果 f 是一个高阶函数,因为它接受一个 Python 函数 g 作为参数,那么存储在 g 的闭包中的数值将无法被自定义导数规则看到,并且涉及这些数值的 AD 尝试将会失败。解决此问题的一种方法是通过提取这些数值来转换闭包,并将它们作为显式形式参数传递到自定义导数边界。此实用程序执行该转换。更准确地说,它对专门用于 example_args 中给出的参数类型的函数 fun 进行闭包转换。

当我们在这里提到 fun 的“闭包中的数值”时,我们并不是指在定义 fun 时 Python 直接捕获的数值(例如,如果存在该属性,则为 fun.__closure__ 中的 Python 对象)。相反,我们指的是在 funexample_args 上执行过程中遇到的确定其输出的数值。例如,这可能包括在 Python 闭包中传递捕获的数组,即在 fun 调用的函数的 Python 闭包中,它们调用的函数的闭包中,等等。

函数 fun 必须是一个纯函数。

示例用法

def minimize(objective_fn, x0):
  converted_fn, aux_args = closure_convert(objective_fn, x0)
  return _minimize(converted_fn, x0, *aux_args)

@partial(custom_vjp, nondiff_argnums=(0,))
def _minimize(objective_fn, x0, *args):
  z = objective_fn(x0, *args)
  # ... find minimizer x_opt ...
  return x_opt

def fwd(objective_fn, x0, *args):
  y = _minimize(objective_fn, x0, *args)
  return y, (y, args)

def rev(objective_fn, res, g):
  y, args = res
  y_bar = g
  # ... custom reverse-mode AD ...
  return x0_bar, *args_bars

_minimize.defvjp(fwd, rev)
参数:
  • fun (Callable) – 要转换的 Python 可调用对象。必须是一个纯函数。

  • example_args – 数组、标量或(嵌套的)标准 Python 容器(元组、列表、字典、命名元组,即 pytrees),用于确定 fun 的形式参数的类型。此专门针对类型的 fun 形式是将进行闭包转换的函数。

返回值:

一对,包含 (i) 一个 Python 可调用对象,接受与 fun 相同的参数,随后是与从其闭包中提升的数值相对应的参数,以及 (ii) 从闭包中提升的数值列表。

返回类型:

tuple[Callable, list[Any]]