JAX 可转换函数的自定义 JVP/VJP 规则#

这是一份设计文档,解释了 jax.custom_jvpjax.custom_vjp 的设计和实现背后的一些思考。有关面向用户的文档,请参阅教程笔记本

在 JAX 中,有两种定义微分规则的方法:

  1. 使用 jax.custom_jvpjax.custom_vjp 为已经是 JAX 可转换的 Python 函数定义自定义微分规则;以及

  2. 定义新的 core.Primitive 实例及其所有转换规则,例如调用来自其他系统(如求解器、模拟器或通用数值计算系统)的函数。

本文档仅讨论 #1。

目录#

目标#

我们希望用户能够自定义其代码的前向和/或反向模式微分行为。这种自定义

  1. 在工作方式以及如何与其他 JAX 转换组合方面,应具有清晰一致的语义;以及

  2. 灵活地支持AutogradPyTorch 中的用例和工作流程,包括涉及 Python 控制流微分和 NaN 调试的工作流程。

作为 JAX 开发人员,我们希望编写库函数(如 logitexpit),这些函数是根据其他原语定义的,但出于微分的目的,它们具有类似原语的行为,因为我们希望为它们定义自定义微分规则,这些规则可能在数值上更稳定或性能更高。特别是,我们不希望必须为 logitexpit 等函数指定 vmapjit 规则。

作为一个延伸目标,我们希望使 JAX 成为高级用户添加自定义高阶函数(如 fixed_pointodeint 等)微分规则的理想环境;本设计文档不会解决该问题,但我们希望确信我们不会排除该问题的良好解决方案。

也就是说,我们的主要目标是

  1. 解决 vmap-移除-自定义-jvp 的语义问题 (#1249);以及

  2. 允许在自定义 VJP 中使用 Python,例如调试 NaN (#1275)。

次要目标是 3. 清理和简化用户体验(符号零、kwargs 等)4. 朝着用户可以轻松添加 fixed_pointodeintroot 等的方向迈进。

总而言之,我们希望关闭 #116#1097#1249#1275#1366#1723#1670#1875#1938,并替换 custom_transforms 机制(来自 #636#818 以及其他)。

非目标#

以下是我们打算实现的目标:

  1. custom_transforms 机制旨在提供一种用于自定义行为的转换通用机制,原则上(尽管在实践中从未真正使用过)允许用户自定义任何转换的规则,同时以某种方式继承其他转换的“透明”行为。我们改为仅解决微分(JVP 和 VJP,分别)的自定义问题。微分是实际请求的唯一案例,通过专门针对微分,我们可以降低复杂性并提高灵活性。要控制所有规则,只需编写一个原语即可。

  2. 我们不会优先考虑数学美感,而优先考虑用户端的灵活性和清晰度,以及实现端的简单性。特别是,虽然自定义 VJP 签名 a -> (b, CT b --o CT a) 在数学上令人满意,但如果由于返回类型中的闭包而难以在 Python 机制中实现,那么我们可以更明确地处理残差。

  3. 序列化支持(形式为可以加载并进一步进行 JAX 转换的暂存序列化程序表示,而不是仅进行评估)目前不在这些自定义 JVP/VJP 转换规则的范围内。序列化不仅对于希望保存其计算的某些表示(并在加载后对其进行转换)的研究人员可能有用,而且对于未来的考虑(如在 Python 外部实现 jaxpr 转换或将 jaxpr 作为 MLIR 方言)也可能有用。通过将此定义为本设计的非目标,我们可以减少对可以在何处存储 Python 可调用对象的限制。

主要问题描述#

vmap-移除-自定义-jvp 的语义问题#

vmap-移除-自定义-jvp 的语义问题是 vmap 不能正确地与具有 custom_transforms 规则的函数的微分组合。

# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  return 2. * x

# f_vjp :: a -> (b, CT b --o CT a)
def f_vjp(x):
  return f(x), lambda g: 3. * x  # 3 instead of 2

jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # 3.
vmap(grad(f))(np.ones(4))  # [3., 3., 3., 3.]
grad(lambda x: vmap(f)(x).sum())(np.ones(4))  # [2., 2., 2., 2.]

最后一行 grad-of-vmap 有一个意想不到的结果!通常,应用 vmap 或任何非微分转换都会导致移除自定义微分规则。(定义自定义 VJP 规则时,应用 jvp 会导致失败。)

该问题存在的原因是转换类似于重写,而 vmap 转换有效地重写了函数,使其不再调用新引入的具有自定义规则的原语(因此 grad 不会生成自定义规则的结果)。更详细地说,custom_transforms 机制将设置,以便评估 f(x) 应用函数

{ lambda  ; ; a.
  let b = f_primitive a
  in [b] }

其中 f_primitive 是一个新的原语(为每个 custom_transforms 函数以及实际上为函数的每次调用引入),自定义 VJP 规则与其关联。当我们评估 grad(f)(x) 时,微分机制会遇到 f_primitive 并使用自定义规则对其进行处理。

但是,由于 f_primitive 对于 vmap透明的,因为 vmap 操作(通过内联) f_primitive 的定义,函数 vmap(f) 实际上是

{ lambda  ; ; a.
  let b = mul 2. a
  in [b] }

换句话说,vmap 根据其底层原语及其转换规则重写函数,完全移除 f_primitive

更普遍地说,因为 vmap(f) 具有根据对 f 的调用定义的语义,所以移除自定义导数规则在语义上是不一致的。也就是说,由于我们定义了

vmap(f)(xs) == np.stack([f(x) for x in xs])

我们必须有

jvp(vmap(f))(xs) == jvp(lambda xs: np.stack([f(x) for x in xs]))

但是,当 f 定义了自定义导数规则时,不会观察到此属性,因为自定义导数规则在右侧版本中使用,而不在左侧版本中使用。

此问题并非特定于 vmap;它适用于所有转换,对于这些转换,转换函数 f 的语义是根据对函数 f 的调用定义的,而不是将其重写为另一个函数。mask 转换也属于此类。微分转换和假设的所有一元函数都变成余弦转换不在此类中。

(额外的自定义规则(如自定义 vmap 规则)之间的交互可能会变得更加复杂,这表明 custom_transforms 的问题框架过于宽泛。)

Python 的灵活性问题#

在 JAX 中,与 AutogradPyTorch 类似(但与 TF1 不同),Python 函数的微分是在函数执行和跟踪时进行的。这种行为让用户感到满意,原因有以下几点。

首先也是最重要的,它支持基于 pdb 的工作流程,例如,用于检查数值或捕获 NaN。 也就是说,用户可以使用标准的 Python 调试器和其他 Python 原生工具来调试他们的代码,甚至可以检查运行时值,以了解示例中的数值行为并捕获像 NaN 这样的根本运行时错误。实际上,在处理与此设计对应的 PR 时,特别是在 odeint 原语上,我多次使用运行时值检查来调试问题,这让我更加确信这是 Python 中一个关键的用户工作流程。一个特别方便的技巧,我已经在 JAX 和 Autograd 中多次使用过,就是在自定义 VJP 规则中插入一个调试器断点,以便在反向传播的特定点进入调试器。

其次,它允许对 Python 原生控制流进行微分。 我们不确定在最终的软件制品中实际使用这种情况的频率有多高,但当用户第一次接触 JAX 或 Autograd 时,他们通常会对这种自由度印象深刻。这就是为什么我们在 JAX 和 Autograd 的 README、幻灯片演示和演示中都将其放在首位的原因。放弃这种能力将是从 Autograd 的倒退。我们希望 JAX 拥有最好的自动微分功能。

然而,custom_transforms 机制不提供这种 Python 支持的灵活性。也就是说,因为它是在用户函数和自定义微分规则的 Python 代码中预先形成 jaxpr 的基础上实现的,所以这样的代码会导致抽象值跟踪错误

# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  if x > 0:
    return x
  else:
    return 0.

def f_vjp(x):
  return ...

jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # Error!

解决方案思路#

主要思想是 dougalm@ 已经通过 core.call 解决了这些问题。也就是说,我们可以将为用户函数指定自定义 JVP 规则的任务,转化为一个新的 Python 级别调用原语(不会添加到 jaxpr 语言中;见下文)。这个新的调用原语与 core.call 一样,有一个与之关联的用户 Python 函数,但此外,还有一个代表 JVP 规则的第二个 Python 可调用对象。我们把这个新的调用原语称为 custom_jvp_call

vmap 这样的变换与 custom_jvp_call 的交互方式与 core.call 的交互方式相同:它们实际上是直接通过它,并应用于底层的 Python 可调用对象。从示意图上看,为了方便起见,使用原语的柯里化版本进行编写,类似于 vmap 通过应用于要调用的函数来与 core.call 进行交互的方式

vmap(call(f)) == call(vmap(f))

对于新的原语 custom_jvp_call,我们只需将 vmap 应用于它所包含的两个函数

vmap(custom_jvp_call(f, f_jvp)) == custom_jvp_call(vmap(f), vmap(f_jvp))

这种行为意味着我们已经解决了 vmap 删除自定义 JVP 语义的问题

jvp 变换的交互方式与预期一致:它只是调用 f_jvp

jvp(call(f)) == call(jvp(f))

jvp(custom_jvp_call(f, f_jvp)) == f_jvp

因为 custom_jvp_call 的行为类似于 core.call(而不是像 xla.xla_call),因为它没有提高其输入的抽象级别(因为它没有延迟任何内容或暂存任何内容),这意味着我们已经解决了 Python 的灵活性问题:用户 Python 函数没有限制(高于 jvpvjp 所要求的通常函数式编程约束)。

那么评估和编译呢?这是两种“退出” JAX 系统的方式,因为在这些步骤之后不能应用额外的变换。因此,它们的规则很简单

eval(call(f)) == eval(f)
jit(call(f)) == hlo_call(jit(f))

eval(custom_jvp_call(f, f_jvp)) == eval(f)
jit(custom_jvp_call(f, f_jvp)) == hlo_call(jit(f))

换句话说,如果 JVP 规则尚未将 custom_jvp_call(f, f_jvp) 重写为 f_jvp,当我们通过 eval 进行评估或通过 jit 暂存到 XLA 时,永远不会应用微分,因此我们只是忽略 f_jvp,其行为就像 core.call。但是,由于接下来讨论的复杂情况,custom_jvp_call 的部分评估规则必须更复杂一些,因为部分评估不仅仅用于通过 jit 暂存到 XLA。

唯一剩下的复杂之处与“初始风格”的 jaxpr 形成原语(如 lax.scan)及其变换规则有关。这些代表了与编译不同的“暂存到 jaxpr”的类型,因为我们可以对暂存的 jaxpr 执行额外的变换。也就是说,当 lax.scan 形成 jaxpr 时,它不会退出变换系统,因为当我们对 lax.scan 应用 jvp 或 vmap 时,我们需要将其应用于 jaxpr 表示的函数。

表达这种复杂性的另一种方式是,像 lax.scan 这样的初始风格的原语依赖于往返于 jaxpr 并返回到 Python 可调用对象的能力,同时保持语义。这也必须意味着保留自定义微分规则的语义。

解决方案是使用一些动态作用域:当我们将 jaxpr 暂存到初始风格的原语(如 lax_control_flow.py 中的原语)时,我们在全局跟踪状态中设置一个位。当该位被设置时,我们使用初始风格的 custom_jvp_call_jaxpr 原语,而不是使用最终风格的 custom_jvp_call 原语,并预先跟踪函数 ff_jvp 到 jaxpr,以方便初始风格的处理。custom_jvp_call_jaxpr 原语在其他方面与最终风格的版本类似。

(脚注:虽然从本质上讲,我们在绑定 custom_jvp_call_jaxpr 之前为 ff_jvp 都形成了 jaxpr,但我们需要延迟形成 f_jvp 的 jaxpr,因为它可能会调用自定义 JVP 函数,因此急切处理会导致无限递归。我们在 thunk 中延迟该 jaxpr 的形成。)

如果我们放弃 Python 的灵活性问题,我们就可以只使用 custom_jvp_call_jaxpr,而无需单独的 Python 级别原语 custom_jvp_call

API#

对于 a -> b 函数的自定义 JVP,使用 (a, Ta) -> (b, T b) 函数来指定

# f :: a -> b
@jax.custom_jvp
def f(x):
  return np.sin(x)

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), np.cos(x) * t

f.defjvp(f_jvp)

(有趣的自动微分:为了使该规则适用于高阶微分,必须在 f_jvp 的主体中调用 f;这排除了 f 的内部和切线计算之间的一些工作共享。)

对于 a -> b 函数的自定义 VJP,使用 a -> (b, c) 前向传递函数与 (c, CT b) -> CT 反向传递函数配对来指定

# f :: a -> b
@jax.custom_vjp
def f(x):
  return np.sin(x)

# f_fwd :: a -> (b, c)
def f_fwd(x):
  return f(x), np.cos(x)

# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, g):
  return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)

签名 a -> (b, CT b --o CT a) 更美观,但支持它会使实现更加复杂,并可能需要牺牲可表达性。基本原因是 Python 可调用对象是不透明的(除非我们急切地将它们跟踪到 jaxpr,这会施加可表达性的约束),并且在这种情况下,我们可能会返回一个可调用对象,其闭包内部有 vmap 跟踪器,我们需要在前向传递期间了解它们。

我们可以添加便利包装器,例如,一次为一个参数定义 JVP 规则(就像我们在内部为原语所做的那样)。但是,由于这个提案已经足够复杂,所以我决定不使用便利层;现在让我们保持最小化。

API 还有其他一些花哨的功能

  • 输入和输出类型 abc 可以是 jaxtypes 的任意 pytree。

  • 当可以使用 inspect 模块将它们解析为位置参数时,支持按名称传递参数(关键字参数)。这有点像使用 Python 3 改进的以编程方式检查参数签名能力的实验。我认为它是合理的,但并不完整,这已经足够好了。(另请参阅 #2069。)

  • 可以使用 nondiff_argnums 将参数标记为不可微分的,并且与 jitstatic_argnums 一样,这些参数不必是 JAX 类型。我们需要为如何将这些参数传递给规则设置约定。对于具有类型签名 (d, a) -> b 的原始函数,其中 d 表示不可微分的类型,JVP 规则的签名是 (a, T a, d) -> T b,而 VJP 规则的反向组件签名是 (d, c, CT b) -> CT a。也就是说,不可微分的参数在自定义 JVP 规则中,在 primalstangents 之后按顺序传递,并在自定义 VJP 规则的反向函数中,在残差之前按顺序传递。

实现说明#

  • 更新的 jax.experimental.odeint

    • 由于 odeint 是自定义 VJP 规则的一个非常复杂的用户,除了将其更新为可以工作之外,我还想将其修改为新的自定义 VJP API 的规范用户,以此来测试该 API 是否是好的。

    • 在此过程中,我对 odeint 的实现进行了其他改进

      • 删除 raveling/unraveling 样板代码

      • 利用 lax.scan 删除索引更新逻辑

      • 在简单摆锤基准测试中提速 20% 以上

  • 在每个转换中为自定义导数调用原语 custom_jvp_callcustom_vjp_call 添加了一个自定义绑定方法。它类似于 core.call_bind,只是我们不处理 env 追踪:这些只是错误。

  • 添加了 custom_lin 原语,它被分阶段输出到线性 jaxprs 中,以便在使用自定义 VJP 规则时进行转置。

    • 由于我们的反向模式自动微分被分解为线性化、部分评估和转置,因此我们的自定义 VJP 规则分两个单独的步骤进行处理:一个在线性化期间,另一个在转置期间。

    • 线性化步骤,即 custom_vjp_call 的 JVP 规则,将 custom_lin 应用于切线值;custom_lin 带有用户自定义的反向传递函数,并且作为原语,它仅具有转置规则。

    • 此机制在 #636 中有更详细的描述。

  • 为了防止