custom_vjp 和 nondiff_argnums 更新指南

custom_vjpnondiff_argnums 更新指南#

mattjj@ 2020 年 10 月 14 日

本文档假设您熟悉 jax.custom_vjp,如 JAX 可变换 Python 函数的自定义导数规则 笔记本中所述。

需要更新的内容#

JAX PR #4008 之后,传递到 custom_vjp 函数的 nondiff_argnums 的参数不能是 Tracer(或 Tracer 的容器),这意味着要允许任意可变换的代码,nondiff_argnums 不应该用于数组值参数。相反,nondiff_argnums 应该仅用于非数组值,例如 Python 可调用对象或形状元组或字符串。

在以前使用 nondiff_argnums 用于数组值的地方,我们应该将它们作为常规参数传递。在 bwd 规则中,我们需要为它们生成值,但我们可以只生成 None 值来指示没有相应的梯度值。

例如,以下是如何编写 clip_gradient 的**旧**方法,当 hi 和/或 lo 是来自 JAX 变换的 Tracer 时,它将不起作用。

from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, None  # no residual values to save

def clip_gradient_bwd(lo, hi, _, g):
  return (jnp.clip(g, lo, hi),)

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

以下是如何编写支持任意变换的**新**方法,它很棒。

import jax

@jax.custom_vjp  # no nondiff_argnums!
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, (lo, hi)  # save lo and hi values as residuals

def clip_gradient_bwd(res, g):
  lo, hi = res
  return (None, None, jnp.clip(g, lo, hi))  # return None for lo and hi

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

如果您使用旧方法而不是新方法,在任何可能出错的情况下,您都会收到一个响亮的错误(即当有 Tracer 传递到 nondiff_argnums 参数时)。

以下是一个我们实际需要 nondiff_argnumscustom_vjp 的情况。

from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def skip_app(f, x):
  return f(x)

def skip_app_fwd(f, x):
  return skip_app(f, x), None

def skip_app_bwd(f, _, g):
  return (g,)

skip_app.defvjp(skip_app_fwd, skip_app_bwd)

解释#

Tracer 传递到 nondiff_argnums 参数一直存在 bug。虽然有些情况能够正常工作,但其他情况会导致复杂且令人困惑的错误消息。

该 bug 的本质是 nondiff_argnums 的实现方式非常像词法闭包。但当时词法闭包在 Tracer 上并非旨在与 custom_jvp/custom_vjp 一起使用。以这种方式实现 nondiff_argnums 是一个错误!

PR #4008 修复了 custom_jvpcustom_vjp 的所有词法闭包问题。 太棒了!也就是说,现在 custom_jvpcustom_vjp 函数和规则可以尽情地闭包在 Tracer 上。对于所有非自动微分转换,事情将正常工作。对于自动微分转换,我们将收到一条关于我们无法相对于 custom_jvpcustom_vjp 闭包的值进行微分的清晰错误消息。

检测到相对于闭包值对 custom_jvp 函数进行微分。这是不支持的,因为 custom JVP 规则只指定了如何相对于显式输入参数对 custom_jvp 函数进行微分。

尝试将闭包值作为参数传递到 custom_jvp 函数中,并调整 custom_jvp 规则。

通过以这种方式加强和稳固 custom_jvpcustom_vjp,我们发现允许 custom_vjp 在其 nondiff_argnums 中接受 Tracer 将需要大量的簿记:我们需要重写用户的 fwd 函数以将值作为残差返回,并重写用户的 bwd 函数以将它们作为正常残差接受(而不是作为特殊的前导参数接受,就像 nondiff_argnums 那样)。这看起来可能可行,直到你考虑到我们必须如何处理任意 pytree!此外,这种复杂性不是必需的:如果用户代码像对待常规参数和残差一样对待类似数组的不可微分参数,那么一切都能正常工作。(在 #4039 之前,JAX 可能会抱怨将整数值输入和输出包含在自动微分中,但在 #4039 之后,这些将正常工作!)

custom_vjp 不同,让 custom_jvp 与是 Tracernondiff_argnums 参数一起工作很容易。因此,这些更新只需要在 custom_vjp 中进行。