custom_vjpnondiff_argnums 更新指南#

mattjj@ 2020 年 10 月 14 日

本文档假定您熟悉 jax.custom_vjp,如 JAX 可转换 Python 函数的自定义导数规则笔记本中所述。

需要更新的内容#

在 JAX PR #4008 之后,传递到 custom_vjp 函数的 nondiff_argnums 中的参数不能是 Tracer(或 Tracer 的容器),这基本上意味着为了允许任意可转换代码,`nondiff_argnums` 不应该用于数组值参数。相反,nondiff_argnums 应该仅用于非数组值,例如 Python 可调用对象或形状元组或字符串。

凡是我们过去使用 nondiff_argnums 来处理数组值的地方,我们应该只是将这些作为常规参数传递。在 bwd 规则中,我们需要为它们生成值,但我们可以只生成 None 值来指示没有相应的梯度值。

例如,这是编写 clip_gradient方法,当 hi 和/或 lo 是来自某些 JAX 转换的 Tracer 时,这种方法将无法工作。

from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, None  # no residual values to save

def clip_gradient_bwd(lo, hi, _, g):
  return (jnp.clip(g, lo, hi),)

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

这是的、超棒的方法,它支持任意转换

import jax

@jax.custom_vjp  # no nondiff_argnums!
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, (lo, hi)  # save lo and hi values as residuals

def clip_gradient_bwd(res, g):
  lo, hi = res
  return (None, None, jnp.clip(g, lo, hi))  # return None for lo and hi

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

如果您使用旧方法而不是新方法,在任何可能出错的情况下(即当有 Tracer 传递到 nondiff_argnums 参数中时),您都会收到一个明显的错误。

这里有一个我们需要使用 nondiff_argnumscustom_vjp 的情况

from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def skip_app(f, x):
  return f(x)

def skip_app_fwd(f, x):
  return skip_app(f, x), None

def skip_app_bwd(f, _, g):
  return (g,)

skip_app.defvjp(skip_app_fwd, skip_app_bwd)

解释#

Tracer 传递到 nondiff_argnums 参数中始终存在错误。虽然有些情况可以正确工作,但其他情况会导致复杂且令人困惑的错误消息。

该错误的本质是 nondiff_argnums 的实现方式非常类似于词法闭包。但是,当时词法闭包覆盖 Tracer 并非旨在与 custom_jvp/custom_vjp 一起使用。以这种方式实现 nondiff_argnums 是一个错误!

PR #4008 修复了 custom_jvpcustom_vjp 的所有词法闭包问题。 耶!也就是说,现在 custom_jvpcustom_vjp 函数和规则可以随心所欲地闭包覆盖 Tracer。对于所有非自动微分转换,一切都将正常工作。对于自动微分转换,我们将收到一条清晰的错误消息,说明为什么我们不能对 custom_jvpcustom_vjp 关闭的值进行微分

检测到相对于闭包值对 custom_jvp 函数进行了微分。这是不支持的,因为 custom JVP 规则仅指定如何相对于显式输入参数对 custom_jvp 函数进行微分。

尝试将闭包值作为参数传递给 custom_jvp 函数,并调整 custom_jvp 规则。

通过这种方式加强和增强 custom_jvpcustom_vjp,我们发现允许 custom_vjp 在其 nondiff_argnums 中接受 Tracer 将需要大量的簿记工作:我们需要重写用户的 fwd 函数以将值作为残差返回,并重写用户的 bwd 函数以将它们作为普通残差接受(而不是像 nondiff_argnums 那样将它们作为特殊的先导参数接受)。这似乎也许可以管理,直到您仔细考虑我们必须如何处理任意的 pytrees!此外,这种复杂性是不必要的:如果用户代码像处理常规参数和残差一样处理类似数组的不可微分参数,那么一切都已经正常工作。(在 #4039 之前,JAX 可能会抱怨在自动微分中涉及整数值输入和输出,但在 #4039 之后,这些将正常工作!)

custom_vjp 不同,很容易使 custom_jvp 与作为 Tracernondiff_argnums 参数一起工作。因此,这些更新只需要在 custom_vjp 中进行。