使用 jax.checkpoint(也称为 jax.remat)控制自动微分的保存值#

import jax
import jax.numpy as jnp

摘要#

jax.checkpoint 装饰器(别名为 jax.remat)与 jax.grad 结合使用,以控制哪些中间值在正向传播中保存,哪些中间值在反向传播中重新计算,从而在内存和 FLOPs 之间进行权衡。

不要错过 实用说明,了解 jax.checkpoint 如何与 jax.jit 交互。

如果不使用 jax.checkpoint,则 jax.grad(f)(x) 的正向传播将保存雅可比系数和其他中间值的值,以便在反向传播中使用。我们将这些保存的值称为残差

def g(W, x):
  y = jnp.dot(W, x)
  return jnp.sin(y)

def f(W1, W2, W3, x):
  x = g(W1, x)
  x = g(W2, x)
  x = g(W3, x)
  return x

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

# Inspect the 'residual' values to be saved on the forward pass
# if we were to evaluate `jax.grad(f)(W1, W2, W3, x)`
from jax.ad_checkpoint import print_saved_residuals
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[5] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[7] output of cos from <ipython-input-4-f510dde58e22>:3 (g)

通过将 jax.checkpoint 应用于子函数(作为装饰器或在特定应用点),我们强制 JAX 不保存该子函数的任何残差。相反,只有 jax.checkpoint 装饰的函数的输入可能会被保存,并且在反向传播中消耗的任何残差都会根据需要从这些输入重新计算。

def f2(W1, W2, W3, x):
  x = jax.checkpoint(g)(W1, x)
  x = jax.checkpoint(g)(W2, x)
  x = jax.checkpoint(g)(W3, x)
  return x

jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)

这里,两个 sin 应用的值被保存下来,因为它们是 jax.checkpoint 装饰的 g 函数的后续应用中的参数,并且是 jax.checkpoint 装饰函数的输入,因此可能会被保存。但没有保存 cos 应用的任何值。

为了控制哪些值可以保存而无需编辑要微分的函数的定义,可以使用重物化策略。以下是一个示例,它只保存没有批次维度的dot操作的结果(因为它们通常是FLOP绑定的,因此值得保存而不是重新计算)。

f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[6] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[7] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)

您还可以使用策略来引用使用jax.ad_checkpoint.checkpoint_name命名的中间值。

from jax.ad_checkpoint import checkpoint_name

def f4(W1, W2, W3, x):
  x = checkpoint_name(g(W1, x), name='a')
  x = checkpoint_name(g(W2, x), name='b')
  x = checkpoint_name(g(W3, x), name='c')
  return x

f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] named 'a' from <ipython-input-7-fc0ed1c14b8d>:4 (f4)

在玩这些玩具示例时,我们可以使用本笔记本中定义的print_fwd_bwd工具更详细地了解正在发生的事情。

from jax.tree_util import tree_flatten, tree_unflatten

from rich.console import Console
from rich.table import Table
import rich.text

def print_fwd_bwd(f, *args, **kwargs) -> None:
  args, in_tree = tree_flatten((args, kwargs))

  def f_(*args):
    args, kwargs = tree_unflatten(in_tree, args)
    return f(*args, **kwargs)

  fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr

  y, f_vjp = jax.vjp(f_, *args)
  res, in_tree = tree_flatten(f_vjp)

  def g_(*args):
    *res, y = args
    f_vjp = tree_unflatten(in_tree, res)
    return f_vjp(y)

  bwd = jax.make_jaxpr(g_)(*res, y).jaxpr

  table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
  table.add_row("[bold green]forward computation:",
                "[bold green]backward computation:")
  table.add_row(rich.text.Text.from_ansi(str(fwd)),
                rich.text.Text.from_ansi(str(bwd)))
  console = Console(width=240, force_jupyter=True)
  console.print(table)

def _renderable_repr(self):
  return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
# no use of jax.checkpoint:
print_fwd_bwd(f, W1, W2, W3, x)
                                                                                                                                                                      
  forward computation:                                                        backward computation:                                                                   
                                                                                                                                                                      
  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let                   { lambda ; a:f32[7] b:f32[6] c:f32[7,6] d:f32[6] e:f32[5] f:f32[6,5] g:f32[5] h:f32[4]  
      e:f32[5] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d        i:f32[5,4] j:f32[7]. let                                                            
      f:f32[5] = sin e                                                            k:f32[7] = mul j a                                                                  
      g:f32[5] = cos e                                                            l:f32[6] = dot_general[dimension_numbers=(([0], [0]), ([], []))] k c                
      h:f32[6] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f        m:f32[7,6] = dot_general[dimension_numbers=(([], []), ([], []))] k b                
      i:f32[6] = sin h                                                            n:f32[6] = mul l d                                                                  
      j:f32[6] = cos h                                                            o:f32[5] = dot_general[dimension_numbers=(([0], [0]), ([], []))] n f                
      k:f32[7] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c i        p:f32[6,5] = dot_general[dimension_numbers=(([], []), ([], []))] n e                
      l:f32[7] = sin k                                                            q:f32[5] = mul o g                                                                  
      m:f32[7] = cos k                                                            r:f32[4] = dot_general[dimension_numbers=(([0], [0]), ([], []))] q i                
    in (l, m, i, c, j, f, b, g, d, a) }                                           s:f32[5,4] = dot_general[dimension_numbers=(([], []), ([], []))] q h                
                                                                                in (s, p, m, r) }                                                                     
# using jax.checkpoint with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)
                                                                                                                                                                             
  forward computation:                                                        backward computation:                                                                          
                                                                                                                                                                             
  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let                   { lambda ; a:f32[5] b:f32[6] c:f32[7] d:f32[5,4] e:f32[6,5] f:f32[7,6] g:f32[4] h:f32[7]. let  
      e:f32[5] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d        i:f32[5,4] j:f32[6,5] k:f32[7,6] l:f32[4] = remat2[                                        
      f:f32[5] = sin e                                                              differentiated=True                                                                      
      g:f32[6] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f          jaxpr={ lambda ; m:f32[5] n:f32[6] o:f32[7] p:f32[5,4] q:f32[6,5] r:f32[7,6]             
      h:f32[6] = sin g                                                                  s:f32[4] t:f32[7]. let                                                               
      i:f32[7] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c h              u:f32[5] = sin m                                                                     
      j:f32[7] = sin i                                                                  v:f32[5] = cos m                                                                     
    in (j, e, g, i, a, b, c, d) }                                                       w:f32[6] = sin n                                                                     
                                                                                        x:f32[6] = cos n                                                                     
                                                                                        y:f32[7] = cos o                                                                     
                                                                                        z:f32[7] = mul t y                                                                   
                                                                                        ba:f32[6] = dot_general[dimension_numbers=(([0], [0]), ([], []))] z r                
                                                                                        bb:f32[6] = mul ba x                                                                 
                                                                                        bc:f32[5] = dot_general[dimension_numbers=(([0], [0]), ([], []))] bb q               
                                                                                        bd:f32[5] = mul bc v                                                                 
                                                                                        be:f32[4] = dot_general[dimension_numbers=(([0], [0]), ([], []))] bd p               
                                                                                        bf:f32[5,4] = dot_general[dimension_numbers=(([], []), ([], []))] bd s               
                                                                                        bg:f32[6,5] = dot_general[dimension_numbers=(([], []), ([], []))] bb u               
                                                                                        bh:f32[7,6] = dot_general[dimension_numbers=(([], []), ([], []))] z w                
                                                                                      in (bf, bg, bh, be) }                                                                  
                                                                                    policy=<function dot_with_no_batch_dims at 0x7f5e469b1700>                               
                                                                                    prevent_cse=True                                                                         
                                                                                  ] a b c d e f g h                                                                          
                                                                                in (i, j, k, l) }                                                                            

让我们一步一步地思考#

您可能需要首先(重新)阅读自动微分食谱第 1 部分

jax.checkpoint的基础#

jax.linearizejax.vjp中,如何以及何时计算某些值都存在灵活性。不同的选择可以在内存使用和 FLOP 之间进行权衡。JAX 使用jax.checkpoint来控制这些选择。

其中一个选择是,是否在正向传递时执行雅可比系数计算,只要输入可用,或者在反向传递时执行,就在需要系数之前。考虑sin_vjp的示例。

def sin_vjp(x):
  y = jnp.sin(x)
  cos_x = jnp.cos(x)
  return y, lambda y_bar: cos_x * y_bar

另一个有效的实现将在反向传递而不是正向传递时计算jnp.cos(x)的值。

def sin_vjp2(x):
  y = jnp.sin(x)
  return y, lambda y_bar: jnp.cos(x) * y_bar

对于这个特定函数,两个版本的内存使用量相同,尽管我们减少了原始计算(即正向传递)的 FLOP,并增加了余切计算(即反向传递)的 FLOP。

在函数组合方面还有另一个选择。回想一下我们对两个函数组合的 VJP 规则。

def f(x):
  y = g(x)
  z = h(y)
  return z

def f_vjp(x):
  y, g_vjp = jax.vjp(g, x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd(z_bar):
    y_bar, = h_vjp(z_bar)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd

另一种选择是

def f_vjp_checkpoint(x):
  y = g(x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd2(z_bar):
    y_bar, = h_vjp(z_bar)
    _, g_vjp = jax.vjp(g, x)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd2

换句话说,这个替代实现不会在正向传递时计算g_vjp或其闭包中的残差值。相反,它只在反向传递f_bwd2中计算它们。这意味着f_vjp_checkpoint需要更少的内存:如果gh各自需要类似数量的残差内存,每个残差内存都比x大得多,那么由f_vjp_checkpoint(x)生成的函数需要的内存是f_vjp(x)的一半!

我们付出的代价是重复的工作:在f_bwd2中,我们必须重新评估g(x)作为jax.vjp(g, x)的一部分,只是为了丢弃它的值(在_, g_vjp = jax.vjp(g, x)这行中的下划线变量中)。

我们可以通过使用jax.checkpoint在原始函数f的替代定义中,在自动微分中获得这种 VJP 行为,而无需直接编写 VJP 函数。

def f_checkpoint(x):
  y = jax.checkpoint(g)(x)
  z = h(y)
  return z

换句话说,我们将jax.checkpoint应用于g,即f的第一阶段,而不是应用于f本身。这样,当我们评估jax.grad(f_checkpoint)(x)时,我们将得到类似以下的计算。

  1. 运行g的正向传递,丢弃残差值;

  2. 运行h的正向传递,保存残差值;

  3. 运行h的反向传递,消耗步骤 2 中的残差值;

  4. 重新运行g的正向传递,保存残差值;

  5. 运行g的反向传递,消耗步骤 4 中的残差值。

也就是说,通过评估jax.grad(f_checkpoint)(x),我们将得到与以下相同的计算。

def f_checkpoint_grad(x):
  y = g(x)                  # step 1
  _, h_vjp = jax.vjp(h)(y)  # step 2
  y_bar, = h_vjp(1.0)       # step 3
  _, g_vjp = jax.vjp(g, x)  # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar

一般来说,jax.checkpoint(foo)是一个新的函数,它与foo具有相同的输入输出行为,但在自动微分方面,尤其是在jax.linearizejax.vjp(及其包装器,如jax.grad)下表现不同,但在jax.jvp下表现相同。当被微分时,只有jax.checkpoint微分函数的输入在正向传递时被存储;在反向传递时,残差值(即来自foo及其雅可比系数值,这些值在反向传递时是需要的)将被重新计算。

请注意,如果f = lambda x: h(g(x))是我们想要微分的函数,即如果我们想要应用jax.grad(f),通过将jax.checkpoint应用于f本身,我们不会获得任何内存节省。这是因为评估jax.grad(jax.checkpoint(f))(x)将导致类似以下的计算。

  1. 运行正向传递,丢弃所有残差值;

  2. 立即重新运行正向传递,保存残差值;

  3. 运行反向传递,消耗步骤 2 中的残差值。

也就是说,在代码中,我们将有类似以下的内容。

def f_grad_bad(x):
  _ = f(x)                  # step 1
  _, f_vjp = jax.vjp(f, x)  # step 2
  x_bar, = f_vjp(1.0)       # step 3
  return x_bar

我们还不会通过将jax.checkpoint应用于hf的第二阶段)来获得任何内存节省。这是因为评估jax.grad(lambda x: jax.checkpoint(h)(g(x)))将导致类似以下的计算。

  1. 运行g的正向传递,保存残差值;

  2. 运行h的正向传递,丢弃残差值;

  3. 立即重新运行h的正向传递,保存残差值;

  4. 运行h的反向传递,消耗步骤 3 中的残差值;

  5. 运行g的反向传递,消耗步骤 1 中的残差值。

也就是说,在代码中,我们将有类似以下的内容。

def f_grad_bad2(x):
  y, g_vjp = jax.vjp(g, x)  # step 1
  z = h(y)                  # step 2
  _, h_vjp = jax.vjp(h, y)  # step 3
  y_bar, = h_vjp(1.0)       # step 3
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar

更一般地说,如果我们有一个函数的链式组合,比如f = lambda x: f3(f2(f1(x))),并且我们对评估jax.grad(f)感兴趣,我们可以说

  • 我们不应该将jax.checkpoint应用于整个函数f,因为那样不会节省任何内存(并且会执行浪费的重新计算);

  • 我们不应该将jax.checkpoint应用于最后一个子函数f3,因为那样不会节省任何内存(并且会执行浪费的重新计算);

  • 我们可以将jax.checkpoint应用于f1f2或它们的组合lambda x: f2(f1(x)),因为任何这些都可能节省内存,并且将表达不同的内存/重新计算权衡。

可保存内容的自定义策略#

如上所述,使用jax.checkpoint会在两种极端之间切换

  • 没有jax.checkpoint,JAX 的自动微分倾向于在正向传递时尽可能地计算所有内容,并将其存储在反向传递中;

  • 使用jax.checkpoint装饰器,我们将在正向传递时尽可能少地计算,并在反向传递时根据需要重新计算值。

为了在这两种极端之间进行操作,保存某些内容,而另一些不保存,我们可以将jax.checkpoint装饰器仔细放置在子函数上。但这需要编辑要微分的函数(例如模型代码),这可能不方便。它也可能难以尝试不同的变化。

因此,另一种选择是使用jax.checkpointpolicy参数。策略是一个可调用对象(即一个函数),它以一阶原始应用程序的类型级规范作为输入,并返回一个布尔值,指示是否允许将相应的输出值作为残差值保存(或者是否必须在(协)切线计算中根据需要重新计算它们)。为了编写健壮的代码,应从jax.checkpoint_policies的属性中选择策略,比如jax.checkpoint_policies.dots_with_no_batch_dims_saveable,因为编写自定义策略可调用对象的 API 被认为是内部的。

例如,考虑要微分的以下函数

def loss(params, x, y):
  return jnp.sum((predict(params, x) - y)**2)

def predict(params, x):
  *Ws, Wlast = params
  for W in Ws:
    x = layer(W, x)
  x = jnp.dot(Wlast, x)
  return x

def layer(W, x):
  return jnp.sin(jnp.dot(W, x))
W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4)
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)

除了在正向传递时保存这么多的值,也许我们只希望保存没有批次维度的矩阵乘法的结果(因为它们可能是 FLOP 而不是内存绑定的)。我们可以使用策略jax.checkpoint_policies.dots_with_no_batch_dims_saveable来实现。

loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y'
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:8 (predict)

请注意,通过提供策略,我们不需要编辑定义losspredictlayer的代码。如果我们想在调用代码(例如训练脚本)中尝试策略,而无需更改库代码(例如神经网络库),这一点尤其方便。

某些策略可以引用使用jax.ad_checkpoint.checkpoint_name命名的值。

def predict(params, x):
  *Ws, Wlast = params
  for i, W in enumerate(Ws):
    x = layer(W, x)
    x = checkpoint_name(x, name=f'layer{i}_output')
  x = jnp.dot(Wlast, x)
  return x

单独使用 checkpoint_name 只是一个标识函数。但由于某些策略函数知道要查找它们,我们可以使用这些名称来控制是否将 checkpoint_name 输出的某些值视为可保存的。

print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer0_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer1_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))
print_saved_residuals(loss_checkpoint2, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y'

另一个引用名称的策略是 jax.checkpoint_policies.save_only_these_names

一些策略包括:

  • everything_saveable(默认策略,就像根本没有使用 jax.checkpoint 一样)

  • nothing_saveable(即重新计算所有内容,就像根本没有使用自定义策略一样)

  • dots_saveable 或其别名 checkpoint_dots

  • dots_with_no_batch_dims_saveable 或其别名 checkpoint_dots_with_no_batch_dims

  • save_anything_but_these_names(保存除 checkpoint_name 输出的任何值,这些值具有给定的任何名称)

  • save_any_names_but_these(仅保存命名值,即任何 checkpoint_name 的输出,除了那些具有给定名称的值)

  • save_only_these_names(仅保存命名值,并且仅在给定的名称之间)

策略仅指示哪些是可以保存的;只有当某个值实际需要进行反向传递时才会保存它。

高级:递归 jax.checkpoint#

通过以正确的方式应用 jax.checkpoint,可以在内存使用量和(重新)计算之间表达许多权衡。一个令人惊讶的例子是*递归*检查点,我们在其中将 jax.checkpoint 应用于本身调用 jax.checkpoint 装饰函数的函数,这样 \(D\) 个函数的链式组合的内存使用量会按 \(\mathcal{O}(\log_2 D)\) 的比例缩放,而不是 \(\mathcal{O}(D)\)

作为一个玩具示例,请考虑多个 jnp.sin 函数的链式组合

def chain_compose(funs):
  def f(x):
    for fun in funs:
      x = fun(x)
    return x
  return f

f = chain_compose([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)

通常,存储的残差数量与链的长度成线性关系

f = chain_compose([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)

但我们可以递归地应用 jax.checkpoint 来改进缩放

def recursive_checkpoint(funs):
  if len(funs) == 1:
    return funs[0]
  elif len(funs) == 2:
    f1, f2 = funs
    return lambda x: f1(f2(x))
  else:
    f1 = recursive_checkpoint(funs[:len(funs)//2])
    f2 = recursive_checkpoint(funs[len(funs)//2:])
    return lambda x: f1(jax.checkpoint(f2)(x))
f = recursive_checkpoint([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f = recursive_checkpoint([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)

这里,与往常一样,成本是重新计算:特别是,我们最终执行了 \(\mathcal{O}(\log_2 D)\) 倍的 FLOPs

f = chain_compose([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
                                                                                                                                 
  forward computation:                  backward computation:                                                                    
                                                                                                                                 
  { lambda ; a:f32[]. let               { lambda ; a:f32[] b:f32[] c:f32[] d:f32[] e:f32[] f:f32[] g:f32[] h:f32[] i:f32[]. let  
      b:f32[] = sin a                       j:f32[] = mul i a                                                                    
      c:f32[] = cos a                       k:f32[] = mul j b                                                                    
      d:f32[] = sin b                       l:f32[] = mul k c                                                                    
      e:f32[] = cos b                       m:f32[] = mul l d                                                                    
      f:f32[] = sin d                       n:f32[] = mul m e                                                                    
      g:f32[] = cos d                       o:f32[] = mul n f                                                                    
      h:f32[] = sin f                       p:f32[] = mul o g                                                                    
      i:f32[] = cos f                       q:f32[] = mul p h                                                                    
      j:f32[] = sin h                     in (q,) }                                                                              
      k:f32[] = cos h                                                                                                            
      l:f32[] = sin j                                                                                                            
      m:f32[] = cos j                                                                                                            
      n:f32[] = sin l                                                                                                            
      o:f32[] = cos l                                                                                                            
      p:f32[] = sin n                                                                                                            
      q:f32[] = cos n                                                                                                            
    in (p, q, o, m, k, i, g, e, c) }                                                                                             
f = recursive_checkpoint([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
                                                                                                                                        
  forward computation:                                                              backward computation:                               
                                                                                                                                        
  { lambda ; a:f32[]. let                                                           { lambda ; a:f32[] b:f32[] c:f32[] d:f32[]. let     
      b:f32[] = remat2[                                                                 e:f32[] = mul d a                               
        differentiated=False                                                            f:f32[] = mul e b                               
        jaxpr={ lambda ; c:f32[]. let d:f32[] = sin c; e:f32[] = sin d in (e,) }        g:f32[] = remat2[                               
        policy=None                                                                       differentiated=True                           
        prevent_cse=True                                                                  jaxpr={ lambda ; h:f32[] i:f32[]. let         
      ] a                                                                                     j:f32[] = sin h                           
      f:f32[] = sin b                                                                         k:f32[] = cos h                           
      g:f32[] = sin f                                                                         l:f32[] = cos j                           
      h:f32[] = sin g                                                                         m:f32[] = mul i l                         
      i:f32[] = sin h                                                                         n:f32[] = mul m k                         
      j:f32[] = sin i                                                                       in (n,) }                                   
      k:f32[] = cos i                                                                     policy=None                                   
      l:f32[] = sin j                                                                     prevent_cse=True                              
      m:f32[] = cos j                                                                   ] c f                                           
    in (l, m, k, g, a) }                                                                o:f32[] = remat2[                               
                                                                                          differentiated=True                           
                                                                                          jaxpr={ lambda ; p:f32[] q:f32[]. let         
                                                                                              r:f32[] = sin p                           
                                                                                              s:f32[] = sin r                           
                                                                                              t:f32[] = sin s                           
                                                                                              u:f32[] = cos s                           
                                                                                              v:f32[] = cos t                           
                                                                                              w:f32[] = mul q v                         
                                                                                              x:f32[] = mul w u                         
                                                                                              y:f32[] = remat2[                         
                                                                                                differentiated=True                     
                                                                                                jaxpr={ lambda ; z:f32[] ba:f32[]. let  
                                                                                                    bb:f32[] = sin z                    
                                                                                                    bc:f32[] = cos z                    
                                                                                                    bd:f32[] = cos bb                   
                                                                                                    be:f32[] = mul ba bd                
                                                                                                    bf:f32[] = mul be bc                
                                                                                                  in (bf,) }                            
                                                                                                policy=None                             
                                                                                                prevent_cse=True                        
                                                                                              ] p x                                     
                                                                                            in (y,) }                                   
                                                                                          policy=None                                   
                                                                                          prevent_cse=True                              
                                                                                        ] 3.0 g                                         
                                                                                      in (o,) }                                         

实用说明#

当被微分的函数被分阶段转移到 XLA 进行编译时,例如通过将 jax.jit 应用于包含 jax.grad 调用的函数,XLA 将自动优化计算,包括有关何时计算或重新计算值的决策。因此,**jax.checkpoint 通常不需要在 jax.jit 下的被微分函数中使用**。XLA 将为您优化这些内容。

一个例外是使用分阶段控制流,例如 jax.lax.scan。跨多个控制流原语(例如,跨正向传递 scan 和相应的反向传递 scan)的自动编译器优化通常不那么彻底。因此,在传递给 jax.lax.scan 的主体函数上使用 jax.checkpoint 通常是一个好主意。

例如,大型 Transformer 模型 中的一种常见模式是将体系结构表示为对层的 jax.lax.scan,以便减少编译时间。也就是说,使用一个简单的全连接网络作为类比,我们不会写这样的东西

LayerParam = tuple[jnp.ndarray, jnp.ndarray]  # weights, bias pair for a layer
ParamsList = list[LayerParam]

def net(params: ParamsList, x: jnp.ndarray):
  for W, b in params:
    x = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return x

相反,我们将使用 jax.lax.scan 对层应用进行迭代

StackedWeights = jnp.ndarray  # all weight matrices stacked together
StackedBiases = jnp.ndarray   # all bias vectors stacked together

all_weights = jnp.stack([W for W, _ in params])
all_biases = jnp.stack([b for _, b in params])

def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

def net(all_weights, all_biases, x):
  x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
  return x

此对层的扫描版本减少了编译时间,但通过破坏某些编译器优化,它会导致对梯度的低效计算。为了缓解这个问题,我们将在扫描函数上使用 jax.checkpoint

from functools import partial

@partial(jax.checkpoint,
         policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

通过这种方式使用 jax.checkpoint,我们手动控制 JAX 的自动微分在正向和反向传递之间保存哪些值,因此不依赖于 XLA 优化来为我们选择。