JAX 可变换函数的自定义 JVP/VJP 规则#

这是一份设计文档,解释了 jax.custom_jvpjax.custom_vjp 设计和实现背后的某些想法。有关面向用户的文档,请参阅教程笔记本

JAX 中有两种方法可以定义微分规则

  1. 使用 jax.custom_jvpjax.custom_vjp 为已可 JAX 变换的 Python 函数定义自定义微分规则;以及

  2. 定义新的 core.Primitive 实例及其所有变换规则,例如调用来自其他系统(如求解器、模拟器或通用数值计算系统)的函数。

本文档仅针对 #1。

内容#

目标#

我们希望 **用户** 自定义其代码的前向和/或反向模式微分行为。此自定义

  1. 应该具有清晰一致的语义,说明其工作原理以及如何与其他 JAX 变换组合;以及

  2. 应该在支持用例和工作流方面灵活,就像在 AutogradPyTorch 中一样,包括涉及 Python 控制流微分和用于 NaN 调试的工作流。

作为JAX 开发人员,我们希望编写库函数,例如 logitexpit,这些函数是由其他基本运算定义的,但为了微分的目的,它们具有类似基本运算的行为,也就是说我们希望为它们定义自定义微分规则,这些规则可能在数值上更稳定或性能更高。特别地,我们不想为像 logitexpit 这样的函数指定 vmapjit 规则。

作为一个扩展目标,我们希望使 JAX 成为一个伟大的环境,供希望为高阶函数(例如 fixed_pointodeint 等)添加自定义微分规则的强大用户使用。本设计文档不会解决这个问题,但我们希望确信自己不会阻止对该问题的良好解决方案。

也就是说,我们的主要目标是

  1. 解决 vmap 移除自定义 jvp 语义问题 (#1249),以及

  2. 允许自定义 VJP 中使用 Python,例如调试 NaN (#1275)。

次要目标是 3. 清理和简化用户体验(符号零、关键字参数等) 4. 向用户可以轻松添加 fixed_pointodeintroot 等的世界迈进。

总的来说,我们希望关闭 #116#1097#1249#1275#1366#1723#1670#1875#1938,并替换 custom_transforms 机制(来自 #636#818 以及其他)。

非目标#

以下是我们打算实现的目标

  1. custom_transforms 机制旨在为自定义行为提供一种转换通用的机制,原则上(尽管在实践中从未真正使用过)允许用户为任何转换自定义规则,同时以某种方式继承其他转换的“透明”行为。我们将只解决微分(JVP 和 VJP,分别)的自定义问题。 微分是唯一实际请求的案例,通过专门针对微分,我们可以降低复杂性并提高灵活性。要控制所有规则,可以简单地编写一个基本运算。

  2. 我们不会优先考虑数学美学,而会优先考虑用户端的灵活性、清晰度和实现端的简洁性。特别是,虽然自定义 VJP 签名 a -> (b, CT b --o CT a) 在数学上很美观,但如果由于返回值类型中的闭包,它难以在 Python 机制中实现,我们很乐意做一些更明确地处理残差的事情。

  3. 序列化支持,即允许用户将阶段化后的序列化程序表示加载并进一步 JAX 转换,而不是仅仅评估,目前不属于这些自定义 JVP/VJP 转换规则的范围。序列化不仅可能对希望保存其计算表示(并在加载后对其进行转换)的研究人员有用,而且对未来的考虑也很有用,例如在 Python 之外实现 jaxpr 转换,或将 jaxprs 作为 MLIR 方言。通过将此定义为本设计目的的非目标,我们对可以存储 Python 可调用的位置有了更少的限制。

主要问题描述#

vmap 移除自定义 jvp 语义问题#

vmap 移除自定义 jvp 语义问题是指 vmap 与具有 custom_transforms 规则的函数的微分不能很好地组合

# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  return 2. * x

# f_vjp :: a -> (b, CT b --o CT a)
def f_vjp(x):
  return f(x), lambda g: 3. * x  # 3 instead of 2

jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # 3.
vmap(grad(f))(np.ones(4))  # [3., 3., 3., 3.]
grad(lambda x: vmap(f)(x).sum())(np.ones(4))  # [2., 2., 2., 2.]

最后一行 vmap 的梯度结果出乎意料!通常,应用 vmap 或任何非微分转换,都会删除自定义微分规则。(当定义了自定义 VJP 规则时,应用 jvp 会导致失败。)

该问题的存在是因为转换类似于重写,而 vmap 转换实际上重写了该函数,使其不再调用新引入的基本运算,而该基本运算存在自定义规则(因此 grad 不会产生自定义规则的结果)。更详细地说,custom_transforms 机制设置了以下内容,以便评估 f(x) 会应用以下函数

{ lambda  ; ; a.
  let b = f_primitive a
  in [b] }

其中 f_primitive 是一个新的基本运算(为每个 custom_transforms 函数引入,实际上为该函数的每次调用引入),自定义 VJP 规则与此基本运算相关联。当我们评估 grad(f)(x) 时,微分机制会遇到 f_primitive 并使用自定义规则对其进行处理。

然而,由于 f_primitivevmap 透明,也就是说 vmap 通过(实际上是通过内联)操作 f_primitive 的定义,函数 vmap(f) 实际上是

{ lambda  ; ; a.
  let b = mul 2. a
  in [b] }

换句话说,vmap 使用其底层基本运算及其转换规则来重写函数,完全删除了 f_primitive

更一般地说,由于 vmap(f) 的语义是根据对 f 的调用定义的,所以删除自定义导数规则在语义上是不一致的。也就是说,因为我们定义

vmap(f)(xs) == np.stack([f(x) for x in xs])

我们必须有

jvp(vmap(f))(xs) == jvp(lambda xs: np.stack([f(x) for x in xs]))

但是,当 f 定义了自定义导数规则时,这个属性没有被观察到,因为在右边的版本中使用了自定义导数规则,而在左边的版本中没有使用。

这个问题并非特定于 vmap;它适用于所有转换,其中转换函数 f 的语义是根据对该函数 f 的调用定义的,而不是将其重写为另一个函数。 mask 转换也属于此类。微分转换和假设的所有一元函数变为余弦的转换不属于此类。

(自定义规则之间的交互,例如自定义 vmap 规则,可能会变得更加复杂,这表明 custom_transforms 的问题框架过于宽泛。)

Python 灵活性问题#

在 JAX 中,与 AutogradPyTorch 相同,但与 TF1 不同,Python 函数的微分是在执行和追踪函数时执行的。这种行为让用户喜欢,原因有以下几点。

首先,也是最重要的是,它支持基于 pdb 的工作流程,例如检查数值或捕获 NaN。 也就是说,用户可以采用标准的 Python 调试器和其他 Python 原生的工具来调试代码,甚至可以检查运行时值,以便了解示例中的数值行为,并捕获诸如 NaN 之类的根本运行时错误。事实上,仅当我在与本设计相关的 PR 上工作时,特别是在 odeint 基本运算上工作时,我多次使用运行时值检查来调试问题,这让我更加确信这是一个 Python 中的关键用户工作流程。我多次在 JAX 和 Autograd 中使用的一个特别方便的技巧是,可以在自定义 VJP 规则中插入一个调试器断点,以便在反向传播过程中的特定点进入调试器。

其次,它允许 Python 原生控制流的微分。 我们不确定这在最终软件工件中被使用的频率,但当用户第一次尝试使用 JAX 或 Autograd 时,他们经常会被这种自由所打动。我们将其包含在我们的 JAX 和 Autograd 自述文件、幻灯片和演示文稿的顶部是有原因的。放弃这种能力将是 Autograd 的退步。我们希望 JAX 拥有最好的自动微分。

然而,custom_transforms 机制没有提供这种 Python 支持的灵活性。也就是说,由于它是在用户函数和自定义微分规则的 Python 代码的基础上预先形成 jaxpr 实现的,因此以下代码会导致抽象值追踪错误

# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  if x > 0:
    return x
  else:
    return 0.

def f_vjp(x):
  return ...

jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # Error!

解决方案思路#

主要思路是dougalm@ 已经用 core.call 解决了这些问题。也就是说,我们可以将为用户函数指定自定义 JVP 规则的任务,视为一个新的 Python 级别的调用基本运算(不会添加到 jaxpr 语言中;见下文)。这个新的调用基本运算与 core.call 一样,也与用户 Python 函数相关联,但此外还具有表示 JVP 规则的第二个 Python 可调用对象。让我们将这个新的调用基本运算称为 custom_jvp_call

vmap 这样的转换与 custom_jvp_call 的交互方式与 core.call 相同:它们实际上会直接通过它,并应用于底层的 Python 可调用对象。以图表的形式表示,为了方便起见,使用基本运算的柯里化版本来表示,类似于 vmap 通过应用于要调用的函数与 core.call 相互作用的方式

vmap(call(f)) == call(vmap(f))

对于新的基元 custom_jvp_call,我们只需对它包含的两个函数应用 vmap

vmap(custom_jvp_call(f, f_jvp)) == custom_jvp_call(vmap(f), vmap(f_jvp))

这种行为意味着我们已经解决了 vmap-removes-custom-jvp 语义问题

jvp 变换按预期的方式交互:它只调用 f_jvp

jvp(call(f)) == call(jvp(f))

jvp(custom_jvp_call(f, f_jvp)) == f_jvp

因为 custom_jvp_call 的行为类似于 core.call(而不是类似于 xla.xla_call),因为它不会提高其输入的抽象级别(因为它没有延迟任何东西或暂存任何东西),这意味着我们已经解决了 Python 灵活性问题:用户 Python 函数没有约束(除了 jvpvjp 所需的常规函数式编程约束)。

评估和编译呢?这两种方法是“退出”JAX 系统的两种方式,因为在此步骤之后无法应用任何其他变换。因此,它们的规则非常简单。

eval(call(f)) == eval(f)
jit(call(f)) == hlo_call(jit(f))

eval(custom_jvp_call(f, f_jvp)) == eval(f)
jit(custom_jvp_call(f, f_jvp)) == hlo_call(jit(f))

换句话说,如果 JVP 规则还没有将 custom_jvp_call(f, f_jvp) 重写为 f_jvp,当我们使用 eval 进行评估或使用 jit 暂存到 XLA 时,微分永远不会被应用,因此我们只需忽略 f_jvp 并且像 core.call 一样执行。但是,由于下一个讨论的细微差别, custom_jvp_call 的部分评估规则必须更复杂一些,因为部分评估不仅用于使用 jit 暂存到 XLA。

唯一剩下的细微差别与“初始样式”的 jaxpr 生成基元(如 lax.scan)及其变换规则有关。它们表示与编译不同的“暂存到 jaxpr”类型,因为我们可以在暂存的 jaxpr 上执行额外的变换。也就是说,当 lax.scan 生成 jaxpr 时,它不会退出变换系统,因为当我们对 lax.scan 应用 jvp 或 vmap 时,我们需要将其应用于 jaxpr 所代表的函数。

另一种描述细微差别的说法是,像 lax.scan 这样的初始样式基元依赖于在保留语义的同时将 jaxpr 来回转换到 Python 可调用对象的能力。这意味着也要保留自定义微分规则的语义。

解决方案是使用一些动态作用域:当我们为初始样式的基元(如 lax_control_flow.py 中的那些基元)暂存到 jaxpr 时,我们在全局跟踪状态上设置一个位。当该位被设置时,我们不使用最终样式的 custom_jvp_call 基元,而是使用初始样式的 custom_jvp_call_jaxpr 基元,并将函数 ff_jvp 提前跟踪到 jaxpr 以便于初始样式的处理。 custom_jvp_call_jaxpr 基元在其他方面类似于最终样式版本。

(注:虽然从道德上讲,我们在绑定 custom_jvp_call_jaxpr 之前为 ff_jvp 都生成了 jaxpr,但我们需要延迟生成 f_jvp 的 jaxpr,因为它可能会调用自定义 JVP 函数,因此急切处理会导致无限递归。我们通过一个 thunk 来延迟 jaxpr 生成。)

如果我们放弃了 Python 灵活性问题,我们就可以只使用 custom_jvp_call_jaxpr,而无需单独的 Python 级基元 custom_jvp_call

API#

a -> b 函数的自定义 JVP 通过 (a, Ta) -> (b, T b) 函数指定。

# f :: a -> b
@jax.custom_jvp
def f(x):
  return np.sin(x)

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), np.cos(x) * t

f.defjvp(f_jvp)

(有趣的自动微分旁注:为了使该规则适用于高阶微分,必须在 f_jvp 的主体中调用 f;这排除了在 f 的内部和切线计算之间进行一些工作共享的可能性。)

a -> b 函数的自定义 VJP 通过 a -> (b, c) 正向传递函数配对一个 (c, CT b) -> CT 反向传递函数来指定。

# f :: a -> b
@jax.custom_vjp
def f(x):
  return np.sin(x)

# f_fwd :: a -> (b, c)
def f_fwd(x):
  return f(x), np.cos(x)

# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, g):
  return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)

签名 a -> (b, CT b --o CT a) 更加美观,但支持它会使实现更加复杂,并且可能需要牺牲表达能力的愿望。基本原因是 Python 可调用对象是不透明的(除非我们急切地将其跟踪到 jaxpr,这会对表达能力施加约束),在这种情况下,我们可能会返回一个在闭包中包含 vmap 跟踪器,而我们需要在正向传递期间了解这些跟踪器。

我们可以添加一些便捷包装器,例如为一次定义一个参数的 JVP 规则(就像我们在内部为基元所做的那样)。但是,由于此提案本身已经足够复杂,所以我决定不添加便捷层;现在让我们保持简单。

API 中还有一些其他功能。

  • 输入和输出类型 abc 可以是 jaxtypes 的任意 pytree。

  • 当可以使用 inspect 模块将它们解析为位置时,支持按名称传递参数(关键字参数)。这在 Python 3 改进的程序化检查参数签名能力方面是一个小实验。我认为它是健全的,但并不完整,这是一个可以接受的地方。(另见 #2069。)

  • 可以使用 nondiff_argnums 将参数标记为不可微分,并且与 jitstatic_argnums 一样,这些参数不必是 JAX 类型。我们需要为如何将这些参数传递给规则设置一个约定。对于具有类型签名 (d, a) -> b 的原始函数,其中 d 表示不可微分类型,JVP 规则的签名是 (a, T a, d) -> T b,而 VJP 规则的反向分量签名是 (d, c, CT b) -> CT a。也就是说,对于自定义 JVP 规则,不可微分参数按顺序在 primalstangents 之后传递,并且对于自定义 VJP 规则的反向函数,不可微分参数按顺序在残差之前传递。

实现说明#

  • 更新了 jax.experimental.odeint

    • 由于 odeint 是自定义 VJP 规则的一个非常复杂的使用者,除了更新它使其能够工作之外,我还希望将其修改为新自定义 VJP API 的典型使用者,以此作为测试 API 是否良好的方法。

    • 在此过程中,我对 odeint 的实现进行了其他改进。

      • 删除了展开/未展开的样板代码。

      • 使用 lax.scan 来移除索引更新逻辑。

      • 在简单摆模型基准测试中速度提高了 20% 以上。

  • 为每个自定义导数调用基元(custom_jvp_callcustom_vjp_call)添加了一个自定义绑定方法。它类似于 core.call_bind,不同的是我们不处理 env 跟踪:这些都是错误。

  • 添加了 custom_lin 基元,它被暂存到线性 jaxpr 中,以便在使用自定义 VJP 规则时进行转置。

    • 因为我们的反向模式自动微分被分解为线性化、部分评估和转置,所以我们的自定义 VJP 规则在两个单独的步骤中进行处理:一个是在线性化期间,另一个是在转置期间。

    • 线性化步骤,即 custom_vjp_call 的 JVP 规则,将 custom_lin 应用于切线值; custom_lin 携带着用户的自定义反向传递函数,并且作为一个基元,它只有转置规则。

    • 此机制在 #636 中有更详细的描述。

  • 为了防止