使用 jax.checkpoint (jax.remat) 的梯度检查点#

在本教程中,您将学习如何使用 jax.checkpoint()(也称为 jax.remat())来控制 JAX 自动微分的保存值,这在机器学习中特别有用。

如果你是自动微分(autodiff)新手,或者需要复习一下,JAX 提供了 自动微分高级自动微分 教程。

要点总结: 使用 jax.checkpoint() 装饰器(别名为 jax.remat())与 jax.grad() 结合使用,以控制在前向传播中保存哪些中间值,以及在反向传播中重新计算哪些中间值,从而在内存和 FLOPs 之间进行权衡。

如果你不使用 jax.checkpoint()jax.grad(f)(x) 的前向传播会存储雅可比系数和其他中间值,以便在反向传播中使用。这些保存的值称为残差

注意: 不要错过 实用注意事项,其中讨论了 jax.checkpoint() 如何与 jax.jit() 交互。

import jax
import jax.numpy as jnp

def g(W, x):
  y = jnp.dot(W, x)
  return jnp.sin(y)

def f(W1, W2, W3, x):
  x = g(W1, x)
  x = g(W2, x)
  x = g(W3, x)
  return x

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

# Inspect the 'residual' values to be saved on the forward pass
# if you were to evaluate `jax.grad(f)(W1, W2, W3, x)`
from jax.ad_checkpoint import print_saved_residuals
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_1031/1801108376.py:6 (g)
f32[5] output of cos from /tmp/ipykernel_1031/1801108376.py:6 (g)
f32[6] output of sin from /tmp/ipykernel_1031/1801108376.py:6 (g)
f32[6] output of cos from /tmp/ipykernel_1031/1801108376.py:6 (g)
f32[7] output of cos from /tmp/ipykernel_1031/1801108376.py:6 (g)

通过将 jax.checkpoint() 应用于子函数,作为装饰器或在特定应用位置,你可以强制 JAX 不保存该子函数的任何残差。相反,可能只保存 jax.checkpoint() 装饰函数的输入,并且在反向传播中使用的任何残差都根据需要从这些输入重新计算。

def f2(W1, W2, W3, x):
  x = jax.checkpoint(g)(W1, x)
  x = jax.checkpoint(g)(W2, x)
  x = jax.checkpoint(g)(W3, x)
  return x

jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_1031/1801108376.py:6 (g)
f32[6] output of sin from /tmp/ipykernel_1031/1801108376.py:6 (g)

这里,两个 sin 应用的值被保存,因为它们是后续 jax.checkpoint() 装饰的 g 函数的参数,并且可以保存 jax.checkpoint() 装饰的函数的输入。但是不保存 cos 应用的值。

为了控制哪些值可以保存,而无需编辑要微分的函数的定义,你可以使用重构策略。以下示例仅保存没有批次维度的 dot 操作的结果(因为它们通常受 FLOP 限制,因此值得保存而不是重新计算)

f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of reduce_precision from /tmp/ipykernel_1031/1801108376.py:5 (g)
f32[6] output of reduce_precision from /tmp/ipykernel_1031/1801108376.py:5 (g)
f32[7] output of reduce_precision from /tmp/ipykernel_1031/1801108376.py:5 (g)

你还可以使用策略来引用你使用 jax.ad_checkpoint.checkpoint_name() 命名的中间值

from jax.ad_checkpoint import checkpoint_name

def f4(W1, W2, W3, x):
  x = checkpoint_name(g(W1, x), name='a')
  x = checkpoint_name(g(W2, x), name='b')
  x = checkpoint_name(g(W3, x), name='c')
  return x

f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of reduce_precision from /tmp/ipykernel_1031/2296542172.py:4 (f4)

在尝试这些玩具示例时,你可以使用此笔记本中定义的自定义 print_fwd_bwd 实用程序更仔细地了解正在发生的事情

from jax.tree_util import tree_flatten, tree_unflatten

from rich.console import Console
from rich.table import Table
import rich.text

def print_fwd_bwd(f, *args, **kwargs) -> None:
  args, in_tree = tree_flatten((args, kwargs))

  def f_(*args):
    args, kwargs = tree_unflatten(in_tree, args)
    return f(*args, **kwargs)

  fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr

  y, f_vjp = jax.vjp(f_, *args)
  res, in_tree = tree_flatten(f_vjp)

  def g_(*args):
    *res, y = args
    f_vjp = tree_unflatten(in_tree, res)
    return f_vjp(y)

  bwd = jax.make_jaxpr(g_)(*res, y).jaxpr

  table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
  table.add_row("[bold green]forward computation:",
                "[bold green]backward computation:")
  table.add_row(rich.text.Text.from_ansi(str(fwd)),
                rich.text.Text.from_ansi(str(bwd)))
  console = Console(width=240, force_jupyter=True)
  console.print(table)

def _renderable_repr(self):
  return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
# Without using `jax.checkpoint`:
print_fwd_bwd(f, W1, W2, W3, x)
                                                                                                                                                         
  forward computation:                                         backward computation:                                                                     
                                                                                                                                                         
  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let    { lambda ; a:f32[4] b:f32[5,4] c:f32[5] d:f32[5] e:f32[6,5] f:f32[6] g:f32[6] h:f32[7,6]  
      e:f32[5] = dot_general[                                      i:f32[7] j:f32[7]. let                                                                
        dimension_numbers=(([1], [0]), ([], []))                   k:f32[7] = mul j i                                                                    
        preferred_element_type=float32                             l:f32[6] = dot_general[                                                               
      ] a d                                                          dimension_numbers=(([0], [0]), ([], []))                                            
      f:f32[5] = sin e                                               preferred_element_type=float32                                                      
      g:f32[5] = cos e                                             ] k h                                                                                 
      h:f32[6] = dot_general[                                      m:f32[7,6] = dot_general[                                                             
        dimension_numbers=(([1], [0]), ([], []))                     dimension_numbers=(([], []), ([], []))                                              
        preferred_element_type=float32                               preferred_element_type=float32                                                      
      ] b f                                                        ] k g                                                                                 
      i:f32[6] = sin h                                             n:f32[6] = mul l f                                                                    
      j:f32[6] = cos h                                             o:f32[5] = dot_general[                                                               
      k:f32[7] = dot_general[                                        dimension_numbers=(([0], [0]), ([], []))                                            
        dimension_numbers=(([1], [0]), ([], []))                     preferred_element_type=float32                                                      
        preferred_element_type=float32                             ] n e                                                                                 
      ] c i                                                        p:f32[6,5] = dot_general[                                                             
      l:f32[7] = sin k                                               dimension_numbers=(([], []), ([], []))                                              
      m:f32[7] = cos k                                               preferred_element_type=float32                                                      
    in (l, d, a, g, f, b, j, i, c, m) }                            ] n d                                                                                 
                                                                   q:f32[5] = mul o c                                                                    
                                                                   r:f32[4] = dot_general[                                                               
                                                                     dimension_numbers=(([0], [0]), ([], []))                                            
                                                                     preferred_element_type=float32                                                      
                                                                   ] q b                                                                                 
                                                                   s:f32[5,4] = dot_general[                                                             
                                                                     dimension_numbers=(([], []), ([], []))                                              
                                                                     preferred_element_type=float32                                                      
                                                                   ] q a                                                                                 
                                                                 in (s, p, m, r) }                                                                       
# Using `jax.checkpoint` with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)
                                                                                                                                                                        
  forward computation:                                                   backward computation:                                                                          
                                                                                                                                                                        
  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let              { lambda ; a:f32[5] b:f32[6] c:f32[7] d:f32[5,4] e:f32[6,5] f:f32[7,6] g:f32[4] h:f32[7]. let  
      e:f32[5] = dot_general[                                                i:f32[5,4] j:f32[6,5] k:f32[7,6] l:f32[4] = remat2[                                        
        dimension_numbers=(([1], [0]), ([], []))                               differentiated=True                                                                      
        preferred_element_type=float32                                         jaxpr={ lambda ; m:f32[5] n:f32[6] o:f32[7] p:f32[5,4] q:f32[6,5] r:f32[7,6]             
      ] a d                                                                        s:f32[4] t:f32[7]. let                                                               
      f:f32[5] = reduce_precision[exponent_bits=8 mantissa_bits=23] e              u:f32[5] = sin m                                                                     
      g:f32[5] = sin f                                                             v:f32[5] = cos m                                                                     
      h:f32[6] = dot_general[                                                      w:f32[6] = sin n                                                                     
        dimension_numbers=(([1], [0]), ([], []))                                   x:f32[6] = cos n                                                                     
        preferred_element_type=float32                                             y:f32[7] = cos o                                                                     
      ] b g                                                                        z:f32[7] = mul t y                                                                   
      i:f32[6] = reduce_precision[exponent_bits=8 mantissa_bits=23] h              ba:f32[6] = dot_general[                                                             
      j:f32[6] = sin i                                                               dimension_numbers=(([0], [0]), ([], []))                                           
      k:f32[7] = dot_general[                                                        preferred_element_type=float32                                                     
        dimension_numbers=(([1], [0]), ([], []))                                   ] z r                                                                                
        preferred_element_type=float32                                             bb:f32[6] = mul ba x                                                                 
      ] c j                                                                        bc:f32[5] = dot_general[                                                             
      l:f32[7] = reduce_precision[exponent_bits=8 mantissa_bits=23] k                dimension_numbers=(([0], [0]), ([], []))                                           
      m:f32[7] = sin l                                                               preferred_element_type=float32                                                     
    in (m, f, i, l, a, b, c, d) }                                                  ] bb q                                                                               
                                                                                   bd:f32[5] = mul bc v                                                                 
                                                                                   be:f32[4] = dot_general[                                                             
                                                                                     dimension_numbers=(([0], [0]), ([], []))                                           
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] bd p                                                                               
                                                                                   bf:f32[5,4] = dot_general[                                                           
                                                                                     dimension_numbers=(([], []), ([], []))                                             
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] bd s                                                                               
                                                                                   bg:f32[6,5] = dot_general[                                                           
                                                                                     dimension_numbers=(([], []), ([], []))                                             
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] bb u                                                                               
                                                                                   bh:f32[7,6] = dot_general[                                                           
                                                                                     dimension_numbers=(([], []), ([], []))                                             
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] z w                                                                                
                                                                                 in (bf, bg, bh, be) }                                                                  
                                                                               policy=<function dot_with_no_batch_dims_saveable at 0x7f6ca10ebc70>                      
                                                                               prevent_cse=True                                                                         
                                                                             ] a b c d e f g h                                                                          
                                                                           in (i, j, k, l) }                                                                            

让我们逐步思考#

注意: 在继续之前,查看 高级自动微分 教程可能会有所帮助。

jax.checkpoint 基础知识#

jax.linearize()jax.vjp() 中,在如何以及何时计算某些值方面具有灵活性。不同的选择可以在内存使用与 FLOPs 之间进行权衡。JAX 使用 jax.checkpoint() 提供对这些选择的控制。

一种这样的选择是在前向传播中,在输入可用时立即执行雅可比系数计算,还是在反向传播中,在需要系数之前执行计算。考虑 sin_vjp 的示例

def sin_vjp(x):
  y = jnp.sin(x)
  cos_x = jnp.cos(x)
  return y, lambda y_bar: cos_x * y_bar

另一种有效的实现方式是在反向传播而不是在前向传播中计算 jnp.cos(x) 的值

def sin_vjp2(x):
  y = jnp.sin(x)
  return y, lambda y_bar: jnp.cos(x) * y_bar

对于这个特定的函数,两个版本使用的内存量相同,尽管你减少了原始计算(前向传播)的 FLOPs,并增加了余切计算(反向传播)的 FLOPs。

在函数组合方面还有另一种选择。回想一下两个函数组合的 VJP 规则

def f(x):
  y = g(x)
  z = h(y)
  return z

def f_vjp(x):
  y, g_vjp = jax.vjp(g, x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd(z_bar):
    y_bar, = h_vjp(z_bar)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd

另一种选择是

def f_vjp_checkpoint(x):
  y = g(x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd2(z_bar):
    y_bar, = h_vjp(z_bar)
    _, g_vjp = jax.vjp(g, x)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd2

用文字来说,这种替代实现不会在前向传播中计算 g_vjp 或其闭包中的残差值。相反,它仅在反向传播 f_bwd2 中计算它们。这意味着 f_vjp_checkpoint 需要更少的内存:如果 gh 各自的残差需要相似的内存量,并且都远大于 x,则 f_vjp_checkpoint(x) 生成的函数所需的内存是 f_vjp(x) 的一半!

你付出的代价是冗余工作:在 f_bwd2 中,你必须重新评估 g(x) 作为 jax.vjp(g, x) 的一部分,只是为了丢弃其值(在 _, g_vjp = jax.vjp(g, x) 行的下划线变量中)。

你可以通过在原始函数 f 的替代定义中使用 jax.checkpoint() 来在自动微分中获得此 VJP 行为,而无需直接编写 VJP 函数

def f_checkpoint(x):
  y = jax.checkpoint(g)(x)
  z = h(y)
  return z

换句话说,你将 jax.checkpoint() 应用于 gf 的第一阶段 — 而不是 f 本身。这样,当你评估 jax.grad(f_checkpoint)(x) 时,你将获得类似以下的计算

  1. 运行 g 的前向传播,丢弃残差值。

  2. 运行 h 的前向传播,保存残差。

  3. 运行 h 的反向传播,消耗步骤 2 中的残差。

  4. 重新运行 g 的前向传播,保存残差。

  5. 运行 g 的反向传播,消耗步骤 4 中的残差。

也就是说,通过评估 jax.grad(f_checkpoint)(x),我们将获得与以下相同的计算

def f_checkpoint_grad(x):
  y = g(x)                  # step 1
  _, h_vjp = jax.vjp(h)(y)  # step 2
  y_bar, = h_vjp(1.0)       # step 3
  _, g_vjp = jax.vjp(g, x)  # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar

一般来说,jax.checkpoint(foo) 是一个新函数,它具有与 foo 相同的输入输出行为,但在自动微分下,尤其是在 jax.linearize()jax.vjp() (及其包装器,如 jax.grad())下行为不同,但在 jax.jvp() 下则不然。当微分时,仅将 jax.checkpoint() 微分函数的输入存储在前向传播中。在反向传播中,将重新计算残差(foo 的中间值及其反向传播所需的雅可比系数)。

请注意,如果 f = lambda x: h(g(x)) 是您想要微分的函数(换句话说,如果您想应用 jax.grad(f)),那么对 f 本身应用 jax.checkpoint() 不会节省任何内存。这是因为计算 jax.grad(jax.checkpoint(f))(x) 会导致如下计算:

  1. 运行前向传播,丢弃所有残差。

  2. 立即重新运行前向传播,保存残差。

  3. 运行反向传播,使用步骤 2 中的残差。

在代码中,您会有类似这样的代码:

def f_grad_bad(x):
  _ = f(x)                  # step 1
  _, f_vjp = jax.vjp(f, x)  # step 2
  x_bar, = f_vjp(1.0)       # step 3
  return x_bar

f 的第二阶段 h 应用 jax.checkpoint() 也不会节省任何内存。这是因为计算 jax.grad(lambda x: jax.checkpoint(h)(g(x))) 会导致如下计算:

  1. 运行 g 的前向传播,保存残差。

  2. 运行 h 的前向传播,丢弃残差。

  3. 立即重新运行 h 的前向传播,保存残差。

  4. 运行 h 的反向传播,使用步骤 3 中的残差。

  5. 运行 g 的反向传播,使用步骤 1 中的残差。

在代码中,您会有类似这样的代码:

def f_grad_bad2(x):
  y, g_vjp = jax.vjp(g, x)  # step 1
  z = h(y)                  # step 2
  _, h_vjp = jax.vjp(h, y)  # step 3
  y_bar, = h_vjp(1.0)       # step 3
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar

更一般地,如果您有一系列函数组合,例如 f = lambda x: f3(f2(f1(x))),并且您对评估 jax.grad(f) 感兴趣,您可以说您:

  • 不应该对整个函数 f 应用 jax.checkpoint(),因为这不会节省任何内存(并且会执行浪费的重新计算)。

  • 不应该对最后一个子函数 f3 应用 jax.checkpoint(),因为这不会节省任何内存(并且会执行浪费的重新计算)。

  • 可以对 f1f2 或它们的组合 lambda x: f2(f1(x)) 应用 jax.checkpoint(),因为其中任何一个都可能节省内存,并且会表达不同的内存/重新计算权衡。

可保存内容的自定义策略#

如目前所示,使用 jax.checkpoint() 会从一个极端切换到另一个极端:

  • 如果没有 jax.checkpoint(),JAX 的自动微分倾向于在前向传播中计算所有可能的值,并将其存储以供反向传播使用。

  • 使用 jax.checkpoint() 装饰器,您将在前向传播中计算尽可能少的值,并在反向传播中根据需要重新计算值。

为了在这两个极端之间操作,保存一些内容而不是其他内容,您可以谨慎地将 jax.checkpoint() 装饰器放置在子函数上。但这需要编辑要微分的函数,例如模型代码,这可能不方便。也很难尝试各种变体。

因此,另一种方法是使用 jax.checkpoint()policy 参数。策略是一个可调用对象(即函数),它将一阶原始应用程序的类型级规范作为输入,并返回一个布尔值,指示是否允许将相应的输出值保存为残差(或者是否必须在(余)切线计算中根据需要重新计算)。为了编写健壮的代码,应该从 jax.checkpoint_policies() 的属性中选择策略,例如 jax.checkpoint_policies.dots_with_no_batch_dims_saveable(),因为编写自定义策略可调用对象的 API 被认为是内部的。

例如,考虑这个要微分的函数:

def loss(params, x, y):
  return jnp.sum((predict(params, x) - y)**2)

def predict(params, x):
  *Ws, Wlast = params
  for W in Ws:
    x = layer(W, x)
  x = jnp.dot(Wlast, x)
  return x

def layer(W, x):
  return jnp.sin(jnp.dot(W, x))
W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4)
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] output of sin from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] output of cos from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] output of sin from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] output of cos from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] output of mul from /tmp/ipykernel_1031/4230705069.py:2 (loss)

也许您只想保存没有批处理维度的矩阵乘法的结果(因为它们可能是受 FLOP 而不是内存限制的),而不是在前向传播中保存这么多值。您可以使用策略 jax.checkpoint_policies.dots_with_no_batch_dims_saveable() 来实现这一点:

loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] from the argument y
f32[4] output of reduce_precision from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] output of reduce_precision from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] output of reduce_precision from /tmp/ipykernel_1031/4230705069.py:8 (predict)

另请注意,通过提供策略,您不需要编辑定义 losspredictlayer 的代码。如果您想在调用代码(例如训练脚本)中尝试策略,而无需更改库代码(例如,神经网络库),这将特别方便。

某些策略可以引用使用 jax.ad_checkpoint.checkpoint_name() 命名的值:

from jax.ad_checkpoint import checkpoint_name

def predict(params, x):
  *Ws, Wlast = params
  for i, W in enumerate(Ws):
    x = layer(W, x)
    x = checkpoint_name(x, name=f'layer{i}_output')
  x = jnp.dot(Wlast, x)
  return x

就其本身而言,jax.ad_checkpoint import.checkpoint_name() 只是一个恒等函数。但由于某些策略函数知道要查找它们,因此您可以使用名称来控制是否将 jax.ad_checkpoint import.checkpoint_name() 输出的某些值视为可保存的:

print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] output of cos from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] named 'layer0_output' from /tmp/ipykernel_1031/178264713.py:7 (predict)
f32[4] output of cos from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] named 'layer1_output' from /tmp/ipykernel_1031/178264713.py:7 (predict)
f32[4] output of mul from /tmp/ipykernel_1031/4230705069.py:2 (loss)
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))
print_saved_residuals(loss_checkpoint2, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] from the argument y

另一个引用名称的策略是 jax.checkpoint_policies.save_only_these_names

用于卸载的自定义策略#

您可能会考虑在检查点以节省加速器内存时卸载到 CPU 内存而不是重新计算。jax.checkpoint_policies.offload_dot_with_no_batch_dims 可以将没有批处理维度的矩阵乘法的结果卸载到 CPU。

from jax.ad_checkpoint import checkpoint

def checkpoint_offload_dot_with_no_batch_dims(self):
  policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
      "device", "pinned_host")

  @functools.partial(checkpoint, policy=policy)
  def f(x):
    x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
    x = jnp.sin(x)
    x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
    x = jnp.sin(x)
    x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
    x = jnp.sin(x)
    x = jnp.sum(x)
    return x

JAX 的一个检查点策略允许将指定的检查点名称卸载到 CPU。此策略通过 jax.checkpoint_policies.save_and_offload_only_these_names 实现,该策略有四个参数:names_which_can_be_savednames_which_can_be_offloaded、卸载源和目标。在 names_which_can_be_saved 中列出的名称保留在设备上,在 names_which_can_be_offloaded 中列出的名称被移动到 CPU 内存,其他名称或没有名称的操作被重新计算。例如,如果我们有检查点名称 yzw,则可以将 y 保存在设备上,可以将 z 卸载到 CPU 内存,并且可以重新计算 w

from jax.ad_checkpoint import checkpoint, checkpoint_name
from jax._src import test_util as jtu

def checkpoint_names_saved_offloaded_recomputed(self):
  mesh = jtu.create_mesh((2,), ("x",))
  shape = (256, 128)
  np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  s = NamedSharding(mesh, P("x"))
  inp = jax.device_put(np_inp, s)

  policy = jax.checkpoint_policies.save_and_offload_only_these_names(
      names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"],
      offload_src='device', offload_dst='pinned_host')

  @functools.partial(checkpoint, policy=policy)
  def f(x):
    def g(ys, _):
      y, _ = ys
      y = checkpoint_name(jnp.sin(y), "y")
      z = checkpoint_name(jnp.sin(y), "z")
      z = z.T
      w = checkpoint_name(jnp.sin(z), "w")
      return (w.T, jnp.sum(w)), None
    _, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
    return scan_out

该代码定义了一个函数 f,该函数使用自定义策略应用检查点。此策略确定在执行期间可以保存或卸载哪些计算。在 f 内部,有一个嵌套函数 g 执行核心计算。jax.lax.scan 函数用于对输入数据重复应用 g

策略列表#

策略包括:

  • everything_saveable(默认策略,就像根本没有使用 jax.checkpoint 一样)

  • nothing_saveable(即,重新物化所有内容,就像根本没有使用自定义策略一样)

  • dots_saveable 或其别名 checkpoint_dots

  • dots_with_no_batch_dims_saveable 或其别名 checkpoint_dots_with_no_batch_dims

  • save_anything_but_these_names(保存任何值,但使用给定的任何名称的 checkpoint_name 的输出除外)

  • save_any_names_but_these(仅保存命名值,即 checkpoint_name 的任何输出,但具有给定名称的输出除外)

  • save_only_these_names(仅保存命名值,并且仅在给定名称中保存)

  • offload_dot_with_no_batch_dimsdots_with_no_batch_dims_saveable 相同,但会卸载到 CPU 内存,而不是重新计算。

  • save_and_offload_only_these_namessave_only_these_names 相同,但会卸载到 CPU 内存,而不是重新计算。

  • save_from_both_policies(policy_1, policy_2) (类似于逻辑 or,因此如果根据 policy_1 *或* policy_2 可保存残差,则该残差可保存)

策略仅指示哪些是可保存的;只有在反向传播实际需要时,才会保存值。

高级:递归 jax.checkpoint#

通过以正确的方式应用 jax.checkpoint(),可以在内存使用和(重新)计算之间进行许多权衡。一个令人惊讶的例子是 *递归* 检查点,您将 jax.checkpoint() 应用于一个函数,该函数本身以某种方式调用了 jax.checkpoint() 修饰的函数,使得来自 \(D\) 个函数的链式组合的内存使用量以 \(\mathcal{O}(\log_2 D)\) 的规模增长,而不是以 \(\mathcal{O}(D)\) 的规模增长。

作为一个玩具示例,考虑多个 jax.numpy.sin() 函数的链式组合

def chain_compose(funs):
  def f(x):
    for fun in funs:
      x = fun(x)
    return x
  return f

f = chain_compose([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)

通常,存储的残差数量与链的长度成线性关系

f = chain_compose([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)

但是您可以递归地应用 jax.checkpoint() 来改善缩放

def recursive_checkpoint(funs):
  if len(funs) == 1:
    return funs[0]
  elif len(funs) == 2:
    f1, f2 = funs
    return lambda x: f1(f2(x))
  else:
    f1 = recursive_checkpoint(funs[:len(funs)//2])
    f2 = recursive_checkpoint(funs[len(funs)//2:])
    return lambda x: f1(jax.checkpoint(f2)(x))
f = recursive_checkpoint([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] from the argument x
f32[] output of sin from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)
f = recursive_checkpoint([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] from the argument x
f32[] output of sin from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)
f32[] output of sin from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)

这里的代价,和往常一样,是重新计算:特别是,您最终执行了 \(\mathcal{O}(\log_2 D)\) 倍的 FLOPs

f = chain_compose([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
                                                                                                                                 
  forward computation:                  backward computation:                                                                    
                                                                                                                                 
  { lambda ; a:f32[]. let               { lambda ; a:f32[] b:f32[] c:f32[] d:f32[] e:f32[] f:f32[] g:f32[] h:f32[] i:f32[]. let  
      b:f32[] = sin a                       j:f32[] = mul i h                                                                    
      c:f32[] = cos a                       k:f32[] = mul j g                                                                    
      d:f32[] = sin b                       l:f32[] = mul k f                                                                    
      e:f32[] = cos b                       m:f32[] = mul l e                                                                    
      f:f32[] = sin d                       n:f32[] = mul m d                                                                    
      g:f32[] = cos d                       o:f32[] = mul n c                                                                    
      h:f32[] = sin f                       p:f32[] = mul o b                                                                    
      i:f32[] = cos f                       q:f32[] = mul p a                                                                    
      j:f32[] = sin h                     in (q,) }                                                                              
      k:f32[] = cos h                                                                                                            
      l:f32[] = sin j                                                                                                            
      m:f32[] = cos j                                                                                                            
      n:f32[] = sin l                                                                                                            
      o:f32[] = cos l                                                                                                            
      p:f32[] = sin n                                                                                                            
      q:f32[] = cos n                                                                                                            
    in (p, c, e, g, i, k, m, o, q) }                                                                                             
f = recursive_checkpoint([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
                                                                                                                                        
  forward computation:                                                              backward computation:                               
                                                                                                                                        
  { lambda ; a:f32[]. let                                                           { lambda ; a:f32[] b:f32[] c:f32[] d:f32[]. let     
      b:f32[] = remat2[                                                                 e:f32[] = mul d c                               
        differentiated=False                                                            f:f32[] = mul e b                               
        jaxpr={ lambda ; c:f32[]. let d:f32[] = sin c; e:f32[] = sin d in (e,) }        g:f32[] = remat2[                               
        policy=None                                                                       differentiated=True                           
        prevent_cse=True                                                                  jaxpr={ lambda ; h:f32[] i:f32[]. let         
      ] a                                                                                     j:f32[] = sin h                           
      f:f32[] = sin b                                                                         k:f32[] = cos h                           
      g:f32[] = sin f                                                                         l:f32[] = cos j                           
      h:f32[] = sin g                                                                         m:f32[] = mul i l                         
      i:f32[] = sin h                                                                         n:f32[] = mul m k                         
      j:f32[] = sin i                                                                       in (n,) }                                   
      k:f32[] = cos i                                                                     policy=None                                   
      l:f32[] = sin j                                                                     prevent_cse=True                              
      m:f32[] = cos j                                                                   ] a f                                           
    in (l, g, a, k, m) }                                                                o:f32[] = remat2[                               
                                                                                          differentiated=True                           
                                                                                          jaxpr={ lambda ; p:f32[] q:f32[]. let         
                                                                                              r:f32[] = sin p                           
                                                                                              s:f32[] = sin r                           
                                                                                              t:f32[] = sin s                           
                                                                                              u:f32[] = cos s                           
                                                                                              v:f32[] = cos t                           
                                                                                              w:f32[] = mul q v                         
                                                                                              x:f32[] = mul w u                         
                                                                                              y:f32[] = remat2[                         
                                                                                                differentiated=True                     
                                                                                                jaxpr={ lambda ; z:f32[] ba:f32[]. let  
                                                                                                    bb:f32[] = sin z                    
                                                                                                    bc:f32[] = cos z                    
                                                                                                    bd:f32[] = cos bb                   
                                                                                                    be:f32[] = mul ba bd                
                                                                                                    bf:f32[] = mul be bc                
                                                                                                  in (bf,) }                            
                                                                                                policy=None                             
                                                                                                prevent_cse=True                        
                                                                                              ] p x                                     
                                                                                            in (y,) }                                   
                                                                                          policy=None                                   
                                                                                          prevent_cse=True                              
                                                                                        ] 3.0 g                                         
                                                                                      in (o,) }                                         

实用注意事项#

当微分函数被暂存到 XLA 进行编译时 — 例如,通过将 jax.jit() 应用于包含 jax.grad() 调用的函数 — XLA 将自动优化计算,包括关于何时计算或重新生成值的决策。因此,对于 jax.jit() 下的微分函数,通常不需要 jax.checkpoint()。 XLA 将为您优化一切。

一个例外是当使用暂存的控制流时,例如 jax.lax.scan()。跨多个控制流原语(例如,跨前向传播 scan 和相应的反向传播 scan)的自动编译器优化通常不那么彻底。因此,通常最好在传递给 jax.lax.scan() 的主体函数上使用 jax.checkpoint()

例如,在大型 Transformer 模型中,一个常见的模式是将架构表示为 jax.lax.scan() (在各层之间),以便减少编译时间。也就是说,使用一个简单的全连接网络作为类比,而不是像这样编写

LayerParam = tuple[jnp.ndarray, jnp.ndarray]  # Weights-bias pair for a layer.
ParamsList = list[LayerParam]

def net(params: ParamsList, x: jnp.ndarray):
  for W, b in params:
    x = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return x

相反,使用 jax.lax.scan() 迭代层应用

params = [(jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5])), 
          (jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5]))]

all_weights = jnp.stack([W for W, _ in params])
all_biases = jnp.stack([b for _, b in params])

def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

def net(all_weights, all_biases, x):
  x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
  return x

这种 scan-over-layers 版本减少了编译时间,但通过阻止一些编译器优化,可能会导致梯度计算效率低下。为了缓解这个问题,您可以在扫描的函数上使用 jax.checkpoint()

from functools import partial

@partial(jax.checkpoint,
         policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

通过以这种方式使用 jax.checkpoint(),您可以手动控制 JAX 的自动微分在正向和反向传播之间保存哪些值,因此不依赖 XLA 优化来为您选择。