使用 jax.checkpoint (jax.remat) 进行梯度检查点#

在本教程中,您将学习如何使用 jax.checkpoint()(也称为 jax.remat())控制 JAX 自动微分的保存值,这在机器学习中特别有用。

如果您不熟悉自动微分 (autodiff) 或需要复习,JAX 提供了 自动微分高级自动微分 教程。

概括来说 使用 jax.checkpoint() 装饰器(别名为 jax.remat())与 jax.grad() 配合,控制前向传递中保存的中间值和反向传递中重新计算的中间值,从而在内存和浮点运算之间进行权衡。

如果不使用 jax.checkpoint(),则 jax.grad(f)(x) 的前向传递会存储雅可比系数和其他中间值,以便在反向传递期间使用。这些保存的值称为残差

注意: 请务必阅读 实用说明,了解 jax.checkpoint() 如何与 jax.jit() 进行交互。

import jax
import jax.numpy as jnp

def g(W, x):
  y = jnp.dot(W, x)
  return jnp.sin(y)

def f(W1, W2, W3, x):
  x = g(W1, x)
  x = g(W2, x)
  x = g(W3, x)
  return x

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

# Inspect the 'residual' values to be saved on the forward pass
# if you were to evaluate `jax.grad(f)(W1, W2, W3, x)`
from jax.ad_checkpoint import print_saved_residuals
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_809/1801108376.py:6 (g)
f32[5] output of cos from /tmp/ipykernel_809/1801108376.py:6 (g)
f32[6] output of sin from /tmp/ipykernel_809/1801108376.py:6 (g)
f32[6] output of cos from /tmp/ipykernel_809/1801108376.py:6 (g)
f32[7] output of cos from /tmp/ipykernel_809/1801108376.py:6 (g)

通过将 jax.checkpoint() 应用于子函数(作为装饰器或在特定应用位置),可以强制 JAX 不保存该子函数的任何残差。相反,只有 jax.checkpoint() 装饰的函数的输入可能会被保存,并且在反向传递中使用的任何残差都会根据需要从这些输入重新计算。

def f2(W1, W2, W3, x):
  x = jax.checkpoint(g)(W1, x)
  x = jax.checkpoint(g)(W2, x)
  x = jax.checkpoint(g)(W3, x)
  return x

jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_809/1801108376.py:6 (g)
f32[6] output of sin from /tmp/ipykernel_809/1801108376.py:6 (g)

这里,两个 sin 应用的结果被保存,因为它们是随后 jax.checkpoint() 装饰的 g 函数的输入,并且 jax.checkpoint() 装饰的函数的输入可能会被保存。但没有保存 cos 应用的结果。

为了控制哪些值可以保存,而无需编辑要微分的函数的定义,可以使用重新计算策略。以下是一个示例,它只保存没有批处理维度的 dot 操作的结果(因为它们通常受浮点运算限制,因此值得保存而不是重新计算)。

f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of reduce_precision from /tmp/ipykernel_809/1801108376.py:5 (g)
f32[6] output of reduce_precision from /tmp/ipykernel_809/1801108376.py:5 (g)
f32[7] output of reduce_precision from /tmp/ipykernel_809/1801108376.py:5 (g)

您还可以使用策略来引用使用 jax.ad_checkpoint.checkpoint_name() 命名的中间值。

from jax.ad_checkpoint import checkpoint_name

def f4(W1, W2, W3, x):
  x = checkpoint_name(g(W1, x), name='a')
  x = checkpoint_name(g(W2, x), name='b')
  x = checkpoint_name(g(W3, x), name='c')
  return x

f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of reduce_precision from /tmp/ipykernel_809/2296542172.py:4 (f4)

在使用这些玩具示例时,您可以使用此笔记本中定义的自定义 print_fwd_bwd 工具更详细地了解正在发生的事情。

from jax.tree_util import tree_flatten, tree_unflatten

from rich.console import Console
from rich.table import Table
import rich.text

def print_fwd_bwd(f, *args, **kwargs) -> None:
  args, in_tree = tree_flatten((args, kwargs))

  def f_(*args):
    args, kwargs = tree_unflatten(in_tree, args)
    return f(*args, **kwargs)

  fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr

  y, f_vjp = jax.vjp(f_, *args)
  res, in_tree = tree_flatten(f_vjp)

  def g_(*args):
    *res, y = args
    f_vjp = tree_unflatten(in_tree, res)
    return f_vjp(y)

  bwd = jax.make_jaxpr(g_)(*res, y).jaxpr

  table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
  table.add_row("[bold green]forward computation:",
                "[bold green]backward computation:")
  table.add_row(rich.text.Text.from_ansi(str(fwd)),
                rich.text.Text.from_ansi(str(bwd)))
  console = Console(width=240, force_jupyter=True)
  console.print(table)

def _renderable_repr(self):
  return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
# Without using `jax.checkpoint`:
print_fwd_bwd(f, W1, W2, W3, x)
                                                                                                                                                         
  forward computation:                                         backward computation:                                                                     
                                                                                                                                                         
  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let    { lambda ; a:f32[4] b:f32[5,4] c:f32[5] d:f32[5] e:f32[6,5] f:f32[6] g:f32[6] h:f32[7,6]  
      e:f32[5] = dot_general[                                      i:f32[7] j:f32[7]. let                                                                
        dimension_numbers=(([1], [0]), ([], []))                   k:f32[7] = mul j i                                                                    
        preferred_element_type=float32                             l:f32[6] = dot_general[                                                               
      ] a d                                                          dimension_numbers=(([0], [0]), ([], []))                                            
      f:f32[5] = sin e                                               preferred_element_type=float32                                                      
      g:f32[5] = cos e                                             ] k h                                                                                 
      h:f32[6] = dot_general[                                      m:f32[7,6] = dot_general[                                                             
        dimension_numbers=(([1], [0]), ([], []))                     dimension_numbers=(([], []), ([], []))                                              
        preferred_element_type=float32                               preferred_element_type=float32                                                      
      ] b f                                                        ] k g                                                                                 
      i:f32[6] = sin h                                             n:f32[6] = mul l f                                                                    
      j:f32[6] = cos h                                             o:f32[5] = dot_general[                                                               
      k:f32[7] = dot_general[                                        dimension_numbers=(([0], [0]), ([], []))                                            
        dimension_numbers=(([1], [0]), ([], []))                     preferred_element_type=float32                                                      
        preferred_element_type=float32                             ] n e                                                                                 
      ] c i                                                        p:f32[6,5] = dot_general[                                                             
      l:f32[7] = sin k                                               dimension_numbers=(([], []), ([], []))                                              
      m:f32[7] = cos k                                               preferred_element_type=float32                                                      
    in (l, d, a, g, f, b, j, i, c, m) }                            ] n d                                                                                 
                                                                   q:f32[5] = mul o c                                                                    
                                                                   r:f32[4] = dot_general[                                                               
                                                                     dimension_numbers=(([0], [0]), ([], []))                                            
                                                                     preferred_element_type=float32                                                      
                                                                   ] q b                                                                                 
                                                                   s:f32[5,4] = dot_general[                                                             
                                                                     dimension_numbers=(([], []), ([], []))                                              
                                                                     preferred_element_type=float32                                                      
                                                                   ] q a                                                                                 
                                                                 in (s, p, m, r) }                                                                       
# Using `jax.checkpoint` with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)
                                                                                                                                                                        
  forward computation:                                                   backward computation:                                                                          
                                                                                                                                                                        
  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let              { lambda ; a:f32[5] b:f32[6] c:f32[7] d:f32[5,4] e:f32[6,5] f:f32[7,6] g:f32[4] h:f32[7]. let  
      e:f32[5] = dot_general[                                                i:f32[5,4] j:f32[6,5] k:f32[7,6] l:f32[4] = remat2[                                        
        dimension_numbers=(([1], [0]), ([], []))                               differentiated=True                                                                      
        preferred_element_type=float32                                         jaxpr={ lambda ; m:f32[5] n:f32[6] o:f32[7] p:f32[5,4] q:f32[6,5] r:f32[7,6]             
      ] a d                                                                        s:f32[4] t:f32[7]. let                                                               
      f:f32[5] = reduce_precision[exponent_bits=8 mantissa_bits=23] e              u:f32[5] = sin m                                                                     
      g:f32[5] = sin f                                                             v:f32[5] = cos m                                                                     
      h:f32[6] = dot_general[                                                      w:f32[6] = sin n                                                                     
        dimension_numbers=(([1], [0]), ([], []))                                   x:f32[6] = cos n                                                                     
        preferred_element_type=float32                                             y:f32[7] = cos o                                                                     
      ] b g                                                                        z:f32[7] = mul t y                                                                   
      i:f32[6] = reduce_precision[exponent_bits=8 mantissa_bits=23] h              ba:f32[6] = dot_general[                                                             
      j:f32[6] = sin i                                                               dimension_numbers=(([0], [0]), ([], []))                                           
      k:f32[7] = dot_general[                                                        preferred_element_type=float32                                                     
        dimension_numbers=(([1], [0]), ([], []))                                   ] z r                                                                                
        preferred_element_type=float32                                             bb:f32[6] = mul ba x                                                                 
      ] c j                                                                        bc:f32[5] = dot_general[                                                             
      l:f32[7] = reduce_precision[exponent_bits=8 mantissa_bits=23] k                dimension_numbers=(([0], [0]), ([], []))                                           
      m:f32[7] = sin l                                                               preferred_element_type=float32                                                     
    in (m, f, i, l, a, b, c, d) }                                                  ] bb q                                                                               
                                                                                   bd:f32[5] = mul bc v                                                                 
                                                                                   be:f32[4] = dot_general[                                                             
                                                                                     dimension_numbers=(([0], [0]), ([], []))                                           
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] bd p                                                                               
                                                                                   bf:f32[5,4] = dot_general[                                                           
                                                                                     dimension_numbers=(([], []), ([], []))                                             
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] bd s                                                                               
                                                                                   bg:f32[6,5] = dot_general[                                                           
                                                                                     dimension_numbers=(([], []), ([], []))                                             
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] bb u                                                                               
                                                                                   bh:f32[7,6] = dot_general[                                                           
                                                                                     dimension_numbers=(([], []), ([], []))                                             
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] z w                                                                                
                                                                                 in (bf, bg, bh, be) }                                                                  
                                                                               policy=<function dot_with_no_batch_dims_saveable at 0x7fb7129d8b80>                      
                                                                               prevent_cse=True                                                                         
                                                                             ] a b c d e f g h                                                                          
                                                                           in (i, j, k, l) }                                                                            

让我们一步一步地思考#

注意: 在继续学习之前,最好先查看 高级自动微分 教程。

jax.checkpoint 基础知识#

jax.linearize()jax.vjp() 中,都可以在如何以及何时计算某些值方面进行灵活选择。不同的选择可以权衡内存使用和浮点运算量。JAX 通过 jax.checkpoint() 提供对这些选择的控制。

其中一个选择是是否在前向传递中(一旦输入可用)或在反向传递中(就在需要系数之前)执行雅可比系数计算。考虑 sin_vjp 的示例。

def sin_vjp(x):
  y = jnp.sin(x)
  cos_x = jnp.cos(x)
  return y, lambda y_bar: cos_x * y_bar

另一种有效的实现是在反向传递中计算 jnp.cos(x) 的值,而不是在前向传递中计算。

def sin_vjp2(x):
  y = jnp.sin(x)
  return y, lambda y_bar: jnp.cos(x) * y_bar

对于此特定函数,这两个版本使用的内存量相同,但您减少了原始计算(前向传递)的浮点运算量,并增加了余切计算(反向传递)的浮点运算量。

在函数组合方面还有另一个选择。回想一下两个函数组合的 VJP 规则。

def f(x):
  y = g(x)
  z = h(y)
  return z

def f_vjp(x):
  y, g_vjp = jax.vjp(g, x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd(z_bar):
    y_bar, = h_vjp(z_bar)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd

另一种选择是

def f_vjp_checkpoint(x):
  y = g(x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd2(z_bar):
    y_bar, = h_vjp(z_bar)
    _, g_vjp = jax.vjp(g, x)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd2

用文字来说,这种替代实现不会在前向传递中计算 g_vjp 或其闭包中的残差值。相反,它只在反向传递 f_bwd2 中计算它们。这意味着 f_vjp_checkpoint 需要更少的内存:如果 gh 各自需要类似数量的内存来存储其残差,并且每个残差都远大于 x,那么由 f_vjp_checkpoint(x) 生成的函数所需的内存只有 f_vjp(x) 的一半!

您需要付出的代价是冗余工作:在 f_bwd2 中,您必须重新评估 g(x) 作为 jax.vjp(g, x) 的一部分,只是为了丢弃其值(在 _, g_vjp = jax.vjp(g, x) 行上的下划线变量中)。

您可以在自动微分中获得此 VJP 行为,而无需直接编写 VJP 函数,而是使用 jax.checkpoint() 在原始函数 f 的替代定义中。

def f_checkpoint(x):
  y = jax.checkpoint(g)(x)
  z = h(y)
  return z

换句话说,您将 jax.checkpoint() 应用于 gf 的第一阶段),而不是应用于 f 本身。这样,当您评估 jax.grad(f_checkpoint)(x) 时,您将获得类似以下的计算:

  1. 运行 g 的前向传递,丢弃残差值。

  2. 运行 h 的前向传递,保存残差。

  3. 运行 h 的反向传递,使用步骤 2 中的残差。

  4. 重新运行 g 的前向传递,保存残差。

  5. 运行 g 的反向传递,使用步骤 4 中的残差。

也就是说,通过评估 jax.grad(f_checkpoint)(x),我们将获得与以下相同的计算结果:

def f_checkpoint_grad(x):
  y = g(x)                  # step 1
  _, h_vjp = jax.vjp(h)(y)  # step 2
  y_bar, = h_vjp(1.0)       # step 3
  _, g_vjp = jax.vjp(g, x)  # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar

一般来说,jax.checkpoint(foo) 是一个新的函数,它与 foo 具有相同的输入输出行为,但在自动微分下,特别是在 jax.linearize()jax.vjp()(以及它们的包装器,如 jax.grad())下,但不是在 jax.jvp() 下的行为不同。在微分时,只有 jax.checkpoint() 微分的函数的输入存储在前向传递中。在反向传递中,残差(来自 foo 的中间值以及反向传递所需的雅可比系数值)将被重新计算。

请注意,如果 f = lambda x: h(g(x)) 是您要微分的函数(换句话说,如果您想要应用 jax.grad(f)),则通过将 jax.checkpoint() 应用于 f 本身不会节省任何内存。这是因为评估 jax.grad(jax.checkpoint(f))(x) 将导致类似以下的计算:

  1. 运行前向传递,丢弃所有残差。

  2. 立即重新运行前向传递,保存残差。

  3. 运行反向传递,使用步骤 2 中的残差。

在代码中,您将拥有类似以下的内容:

def f_grad_bad(x):
  _ = f(x)                  # step 1
  _, f_vjp = jax.vjp(f, x)  # step 2
  x_bar, = f_vjp(1.0)       # step 3
  return x_bar

您也不会通过将 jax.checkpoint() 应用于 hf 的第二阶段)来节省任何内存。这是因为评估 jax.grad(lambda x: jax.checkpoint(h)(g(x))) 将导致类似以下的计算:

  1. 运行 g 的前向传递,保存残差。

  2. 运行 h 的前向传递,丢弃残差。

  3. 立即重新运行 h 的前向传递,保存残差。

  4. 运行 h 的反向传递,使用步骤 3 中的残差。

  5. 运行 g 的反向传递,使用步骤 1 中的残差。

在代码中,您将拥有类似以下的内容:

def f_grad_bad2(x):
  y, g_vjp = jax.vjp(g, x)  # step 1
  z = h(y)                  # step 2
  _, h_vjp = jax.vjp(h, y)  # step 3
  y_bar, = h_vjp(1.0)       # step 3
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar

更一般地,如果您有一系列函数的组合,例如 f = lambda x: f3(f2(f1(x))),并且您有兴趣评估 jax.grad(f),您可以说:

  • 不应将 jax.checkpoint() 应用于整个函数 f,因为这样做不会节省任何内存(并且会执行浪费性的重新计算)。

  • 不应该将 jax.checkpoint() 应用于最后一个子函数 f3,因为这样做不会节省任何内存(并且会执行浪费的重新计算)。

  • 可以将 jax.checkpoint() 应用于 f1f2 或它们的组合 lambda x: f2(f1(x)),因为这些操作中的任何一个都可能节省内存,并且会体现不同的内存/重新计算权衡。

可保存内容的自定义策略#

如前所述,使用 jax.checkpoint() 会在两种极端情况之间切换

  • 没有 jax.checkpoint(),JAX 的自动微分倾向于在正向传递中计算所有可能的内容,并将其存储以用于反向传递。

  • 使用 jax.checkpoint() 装饰器,您改为在正向传递中尽可能少地计算,并在反向传递中根据需要重新计算值。

为了在这两种极端情况之间进行操作,保存某些内容而不保存其他内容,您可以小心地将 jax.checkpoint() 装饰器放置在子函数上。但这需要编辑要微分的函数(例如模型代码),这可能不方便。它也可能难以尝试不同的变化。

因此,另一种方法是使用 jax.checkpoint()policy 参数。策略是一个可调用对象(即一个函数),它以一阶原始应用的类型级别规范作为输入,并返回一个布尔值,指示相应的输出值是否允许作为残差保存(或者必须在(共)切线计算中根据需要重新计算)。为了编写健壮的代码,策略应从 jax.checkpoint_policies() 上的属性中选择,例如 jax.checkpoint_policies.dots_with_no_batch_dims_saveable(),因为编写自定义策略可调用对象的 API 被认为是内部的。

例如,考虑此要微分的函数

def loss(params, x, y):
  return jnp.sum((predict(params, x) - y)**2)

def predict(params, x):
  *Ws, Wlast = params
  for W in Ws:
    x = layer(W, x)
  x = jnp.dot(Wlast, x)
  return x

def layer(W, x):
  return jnp.sin(jnp.dot(W, x))
W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4)
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] output of sin from /tmp/ipykernel_809/4230705069.py:12 (layer)
f32[4] output of cos from /tmp/ipykernel_809/4230705069.py:12 (layer)
f32[4] output of sin from /tmp/ipykernel_809/4230705069.py:12 (layer)
f32[4] output of cos from /tmp/ipykernel_809/4230705069.py:12 (layer)
f32[4] output of mul from /tmp/ipykernel_809/4230705069.py:2 (loss)

与其在正向传递中保存如此多的值,也许您只希望保存没有批处理维度的矩阵乘法的结果(因为它们可能是 FLOP 而不是内存绑定)。您可以使用策略 jax.checkpoint_policies.dots_with_no_batch_dims_saveable() 来做到这一点。

loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] from the argument y
f32[4] output of reduce_precision from /tmp/ipykernel_809/4230705069.py:12 (layer)
f32[4] output of reduce_precision from /tmp/ipykernel_809/4230705069.py:12 (layer)
f32[4] output of reduce_precision from /tmp/ipykernel_809/4230705069.py:8 (predict)

还要注意,通过提供策略,您不需要编辑定义 losspredictlayer 的代码。如果您想在调用代码(例如训练脚本)中试验策略而无需更改库代码(例如神经网络库),这一点尤其方便。

某些策略可以引用使用 jax.ad_checkpoint.checkpoint_name() 命名的值。

from jax.ad_checkpoint import checkpoint_name

def predict(params, x):
  *Ws, Wlast = params
  for i, W in enumerate(Ws):
    x = layer(W, x)
    x = checkpoint_name(x, name=f'layer{i}_output')
  x = jnp.dot(Wlast, x)
  return x

就其本身而言,jax.ad_checkpoint import.checkpoint_name() 只是一个恒等函数。但由于某些策略函数知道要查找它们,因此您可以使用这些名称来控制 jax.ad_checkpoint import.checkpoint_name() 输出的某些值是否被认为是可保存的。

print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] output of cos from /tmp/ipykernel_809/4230705069.py:12 (layer)
f32[4] named 'layer0_output' from /tmp/ipykernel_809/178264713.py:7 (predict)
f32[4] output of cos from /tmp/ipykernel_809/4230705069.py:12 (layer)
f32[4] named 'layer1_output' from /tmp/ipykernel_809/178264713.py:7 (predict)
f32[4] output of mul from /tmp/ipykernel_809/4230705069.py:2 (loss)
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))
print_saved_residuals(loss_checkpoint2, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] from the argument y

另一个引用名称的策略是 jax.checkpoint_policies.save_only_these_names

一些策略是

  • everything_saveable:默认策略,就像根本没有使用 jax.checkpoint() 一样。

  • nothing_saveable:也就是说,重新物化所有内容,就像根本没有使用自定义策略一样。

  • dots_saveable:或其别名 checkpoint_dots

  • dots_with_no_batch_dims_saveable:或其别名 checkpoint_dots_with_no_batch_dims

  • save_anything_but_these_names:保存任何值,除了使用任何给定名称的 checkpoint_name 的输出。

  • save_any_names_but_these:仅保存命名值(即任何输出)。checkpoint_name:除了给定名称的那些。

  • save_only_these_names:仅保存命名值,并且仅限于给定的名称。

策略仅指示哪些内容是可保存的;只有当反向传递实际需要某个值时,才会保存该值。

高级:递归 jax.checkpoint#

通过以正确的方式应用 jax.checkpoint(),可以表达内存使用和(重新)计算之间的许多权衡。一个令人惊讶的例子是递归检查点,您将 jax.checkpoint() 应用于本身调用 jax.checkpoint() 装饰函数的方式,以便 \(D\) 个函数的链式组合的内存使用量按 \(\mathcal{O}(\log_2 D)\) 而不是 \(\mathcal{O}(D)\) 的比例缩放。

作为一个玩具示例,考虑多个 jax.numpy.sin() 函数的链式组合

def chain_compose(funs):
  def f(x):
    for fun in funs:
      x = fun(x)
    return x
  return f

f = chain_compose([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)

通常,存储的残差数量与链的长度成线性关系。

f = chain_compose([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_809/410288286.py:4 (f)

但您可以递归地应用 jax.checkpoint() 来改进缩放。

def recursive_checkpoint(funs):
  if len(funs) == 1:
    return funs[0]
  elif len(funs) == 2:
    f1, f2 = funs
    return lambda x: f1(f2(x))
  else:
    f1 = recursive_checkpoint(funs[:len(funs)//2])
    f2 = recursive_checkpoint(funs[len(funs)//2:])
    return lambda x: f1(jax.checkpoint(f2)(x))
f = recursive_checkpoint([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] from the argument x
f32[] output of sin from /tmp/ipykernel_809/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_809/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_809/1943107544.py:6 (<lambda>)
f = recursive_checkpoint([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] from the argument x
f32[] output of sin from /tmp/ipykernel_809/1943107544.py:6 (<lambda>)
f32[] output of sin from /tmp/ipykernel_809/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_809/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_809/1943107544.py:6 (<lambda>)

这里,像往常一样,成本是重新计算:特别是,您最终执行了 \(\mathcal{O}(\log_2 D)\) 倍的 FLOP。

f = chain_compose([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
                                                                                                                                 
  forward computation:                  backward computation:                                                                    
                                                                                                                                 
  { lambda ; a:f32[]. let               { lambda ; a:f32[] b:f32[] c:f32[] d:f32[] e:f32[] f:f32[] g:f32[] h:f32[] i:f32[]. let  
      b:f32[] = sin a                       j:f32[] = mul i h                                                                    
      c:f32[] = cos a                       k:f32[] = mul j g                                                                    
      d:f32[] = sin b                       l:f32[] = mul k f                                                                    
      e:f32[] = cos b                       m:f32[] = mul l e                                                                    
      f:f32[] = sin d                       n:f32[] = mul m d                                                                    
      g:f32[] = cos d                       o:f32[] = mul n c                                                                    
      h:f32[] = sin f                       p:f32[] = mul o b                                                                    
      i:f32[] = cos f                       q:f32[] = mul p a                                                                    
      j:f32[] = sin h                     in (q,) }                                                                              
      k:f32[] = cos h                                                                                                            
      l:f32[] = sin j                                                                                                            
      m:f32[] = cos j                                                                                                            
      n:f32[] = sin l                                                                                                            
      o:f32[] = cos l                                                                                                            
      p:f32[] = sin n                                                                                                            
      q:f32[] = cos n                                                                                                            
    in (p, c, e, g, i, k, m, o, q) }                                                                                             
f = recursive_checkpoint([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
                                                                                                                                        
  forward computation:                                                              backward computation:                               
                                                                                                                                        
  { lambda ; a:f32[]. let                                                           { lambda ; a:f32[] b:f32[] c:f32[] d:f32[]. let     
      b:f32[] = remat2[                                                                 e:f32[] = mul d c                               
        differentiated=False                                                            f:f32[] = mul e b                               
        jaxpr={ lambda ; c:f32[]. let d:f32[] = sin c; e:f32[] = sin d in (e,) }        g:f32[] = remat2[                               
        policy=None                                                                       differentiated=True                           
        prevent_cse=True                                                                  jaxpr={ lambda ; h:f32[] i:f32[]. let         
      ] a                                                                                     j:f32[] = sin h                           
      f:f32[] = sin b                                                                         k:f32[] = cos h                           
      g:f32[] = sin f                                                                         l:f32[] = cos j                           
      h:f32[] = sin g                                                                         m:f32[] = mul i l                         
      i:f32[] = sin h                                                                         n:f32[] = mul m k                         
      j:f32[] = sin i                                                                       in (n,) }                                   
      k:f32[] = cos i                                                                     policy=None                                   
      l:f32[] = sin j                                                                     prevent_cse=True                              
      m:f32[] = cos j                                                                   ] a f                                           
    in (l, g, a, k, m) }                                                                o:f32[] = remat2[                               
                                                                                          differentiated=True                           
                                                                                          jaxpr={ lambda ; p:f32[] q:f32[]. let         
                                                                                              r:f32[] = sin p                           
                                                                                              s:f32[] = sin r                           
                                                                                              t:f32[] = sin s                           
                                                                                              u:f32[] = cos s                           
                                                                                              v:f32[] = cos t                           
                                                                                              w:f32[] = mul q v                         
                                                                                              x:f32[] = mul w u                         
                                                                                              y:f32[] = remat2[                         
                                                                                                differentiated=True                     
                                                                                                jaxpr={ lambda ; z:f32[] ba:f32[]. let  
                                                                                                    bb:f32[] = sin z                    
                                                                                                    bc:f32[] = cos z                    
                                                                                                    bd:f32[] = cos bb                   
                                                                                                    be:f32[] = mul ba bd                
                                                                                                    bf:f32[] = mul be bc                
                                                                                                  in (bf,) }                            
                                                                                                policy=None                             
                                                                                                prevent_cse=True                        
                                                                                              ] p x                                     
                                                                                            in (y,) }                                   
                                                                                          policy=None                                   
                                                                                          prevent_cse=True                              
                                                                                        ] 3.0 g                                         
                                                                                      in (o,) }                                         

实践注意事项#

当被微分的函数被分阶段输出到 XLA 进行编译时——例如,通过将 jax.jit() 应用于包含 jax.grad() 调用的函数——XLA 将自动优化计算,包括何时计算或重新物化值的决策。因此,jax.checkpoint() 通常不需要用于 jax.jit() 下的被微分的函数。XLA 会为您优化。

一个例外是使用分阶段输出的控制流,例如 jax.lax.scan()。跨多个控制流原语(例如,跨正向传递 scan 和相应反向传递 scan)的自动编译器优化通常并不那么彻底。因此,在传递给 jax.lax.scan() 的主体函数上使用 jax.checkpoint() 通常是一个好主意。

例如,大型 Transformer 模型 中的一个常见模式是将架构表示为 jax.lax.scan() 上的层,以减少编译时间。也就是说,使用简单的全连接网络作为类比,而不是编写类似这样的内容

LayerParam = tuple[jnp.ndarray, jnp.ndarray]  # Weights-bias pair for a layer.
ParamsList = list[LayerParam]

def net(params: ParamsList, x: jnp.ndarray):
  for W, b in params:
    x = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return x

相反,使用 jax.lax.scan() 迭代层应用。

params = [(jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5])), 
          (jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5]))]

all_weights = jnp.stack([W for W, _ in params])
all_biases = jnp.stack([b for _, b in params])

def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

def net(all_weights, all_biases, x):
  x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
  return x

此基于扫描的层版本减少了编译时间,但通过阻止某些编译器优化,它可能导致梯度的计算效率低下。为了缓解此问题,您可以在扫描函数上使用 jax.checkpoint()

from functools import partial

@partial(jax.checkpoint,
         policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

通过以这种方式使用 jax.checkpoint(),您正在手动控制 JAX 的自动微分在正向和反向传递之间保存哪些值,因此不依赖于 XLA 优化为您做出选择。