jax.checkpoint#
- jax.checkpoint(fun, *, prevent_cse=True, policy=None, static_argnums=())[源代码]#
使
fun
在微分时重新计算内部线性化点。jax.checkpoint()
装饰器,别名为jax.remat()
,提供了一种在自动微分的上下文中权衡计算时间和内存成本的方法,尤其是对于像jax.grad()
和jax.vjp()
这样的反向模式自动微分,但也包括jax.linearize()
。在反向模式中对函数进行微分时,默认情况下,所有线性化点(例如,逐元素非线性原语操作的输入)在评估正向传递时都会被存储,以便在反向传递中重用。这种评估策略可能会导致高内存成本,甚至在内存访问比 FLOPs 昂贵得多的硬件加速器上导致性能不佳。
另一种评估策略是重新计算(即重新物化)某些线性化点,而不是存储它们。这种方法可以减少内存使用,但会增加计算量。
此函数装饰器会生成
fun
的新版本,该版本遵循重新物化策略,而不是默认的存储所有内容的策略。也就是说,它返回fun
的新版本,该版本在进行微分时不会存储任何中间线性化点。相反,这些线性化点会从函数的保存输入中重新计算。请参阅下面的示例。
- 参数:
fun (Callable) – 要将其自动微分评估策略从存储所有中间线性化点的默认值更改为重新计算的函数。其参数和返回值应为数组、标量或(嵌套的)标准 Python 容器(元组/列表/字典)及其组合。
prevent_cse (bool) – 可选的、仅限关键字的布尔参数,指示是否在从微分生成的 HLO 中阻止公共子表达式消除 (CSE) 优化。这种 CSE 阻止是有代价的,因为它可能会阻碍其他优化,并且在某些后端(尤其是 GPU)上可能会产生较高的开销。默认值为 True,因为否则在
jit()
或pmap()
下,CSE 可能会破坏此装饰器的目的。但在某些设置中,例如在scan()
中使用时,此 CSE 阻止机制是不必要的,在这种情况下,可以将prevent_cse
设置为 False。static_argnums (int | tuple[int, ...]) – 可选的、int 或 int 序列,一个仅限关键字的参数,指示要为跟踪和缓存目的而专门化的参数值。将参数指定为静态可以避免在跟踪时出现 ConcretizationTypeErrors,但代价是会增加重新跟踪的开销。请参阅下面的示例。
policy (Callable[..., bool] | None | None) – 可选的、仅限关键字的可调用参数。它应该是
jax.checkpoint_policies
的属性之一。该可调用对象将一阶原语应用程序的类型级规范作为输入,并返回一个布尔值,指示相应的输出值是否可以保存为残差(或者如果需要,则必须在(余)切线计算中重新计算)。
- 返回值:
一个函数(可调用对象),其输入/输出行为与
fun
相同,但当使用例如jax.grad()
、jax.vjp()
或jax.linearize()
进行微分时,会重新计算而不是存储中间线性化点,从而可能以额外的计算为代价节省内存。- 返回类型:
Callable
这是一个简单的示例
>>> import jax >>> import jax.numpy as jnp
>>> @jax.checkpoint ... def g(x): ... y = jnp.sin(x) ... z = jnp.sin(y) ... return z ... >>> jax.value_and_grad(g)(2.0) (Array(0.78907233, dtype=float32, weak_type=True), Array(-0.2556391, dtype=float32, weak_type=True))
在此示例中,无论是否存在
jax.checkpoint()
装饰器,都会生成相同的值。当不存在装饰器时,jnp.cos(2.0)
和jnp.cos(jnp.sin(2.0))
的值将在正向传递中计算,并存储以供反向传递使用,因为它们在反向传递中是必需的,并且仅依赖于原始输入。当使用jax.checkpoint()
时,正向传递将仅计算原始输出,并且仅原始输入 (2.0
) 将被存储以供反向传递使用。那时,jnp.sin(2.0)
的值将被重新计算,同时重新计算jnp.cos(2.0)
和jnp.cos(jnp.sin(2.0))
的值。虽然
jax.checkpoint()
控制从正向传递存储哪些值以供反向传递使用,但评估函数或其 VJP 所需的总内存量取决于该函数的许多其他内部细节。这些细节包括使用了哪些数值原语、它们是如何组合的、jit 和控制流原语(如 scan)在哪里使用,以及其他因素。jax.checkpoint()
装饰器可以递归应用以表达复杂的自动微分重新物化策略。例如>>> def recursive_checkpoint(funs): ... if len(funs) == 1: ... return funs[0] ... elif len(funs) == 2: ... f1, f2 = funs ... return lambda x: f1(f2(x)) ... else: ... f1 = recursive_checkpoint(funs[:len(funs)//2]) ... f2 = recursive_checkpoint(funs[len(funs)//2:]) ... return lambda x: f1(jax.checkpoint(f2)(x)) ...
如果
fun
涉及依赖于参数值的 Python 控制流,则可能需要使用static_argnums
参数。例如,考虑一个布尔标志参数from functools import partial @partial(jax.checkpoint, static_argnums=(1,)) def foo(x, is_training): if is_training: ... else: ...
在此示例中,使用
static_argnums
允许if
语句的条件依赖于is_training
的值。使用static_argnums
的代价是它会在调用之间引入重新跟踪的开销:在示例中,每次使用is_training
的新值调用foo
时,都会重新跟踪它。在某些情况下,也需要jax.ensure_compile_time_eval
@partial(jax.checkpoint, static_argnums=(1,)) def foo(x, y): with jax.ensure_compile_time_eval(): y_pos = y > 0 if y_pos: ... else: ...
作为使用
static_argnums
(和jax.ensure_compile_time_eval
)的替代方法,可能更容易在jax.checkpoint()
装饰的函数外部计算某些值,然后将它们封闭起来。