持久性编译缓存#

JAX 具有可选的磁盘缓存,用于编译后的程序。如果启用,JAX 会将编译后的程序副本存储在磁盘上,这可以在重复运行相同或类似任务时节省重新编译时间。

用法#

快速入门#

import jax
import jax.numpy as jnp

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

@jax.jit
def f(x):
  return x + 1

x = jnp.zeros((2, 2))
f(x)

设置缓存目录#

缓存位置 设置时,编译缓存被启用。这应该在第一次编译之前完成。按如下方式设置位置

(1) 使用环境变量

在 shell 中,运行脚本之前

export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache"

或者在 Python 脚本的顶部

import os
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"

(2) 使用 jax.config.update()

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")

(3) 使用 set_cache_dir()

from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir("/tmp/jax_cache")

缓存阈值#

  • jax_persistent_cache_min_compile_time_secs: 只有当编译时间长于指定值时,才会将计算写入持久缓存。默认值为 1.0 秒。

  • jax_persistent_cache_min_entry_size_bytes: 将缓存到持久编译缓存中的条目的最小大小(以字节为单位)

    • -1: 禁用大小限制并阻止覆盖。

    • 保留默认值(0)以允许覆盖。覆盖通常会确保最小大小对于用于缓存的文件系统来说是最优的。

    • > 0: 所需的实际最小大小; 不覆盖。

请注意,这两个条件都需要满足,才能缓存函数。

Google Cloud#

在 Google Cloud 上运行时,编译缓存可以放在 Google Cloud Storage (GCS) 存储桶中。我们建议使用以下配置

  • 在与工作负载运行所在区域相同的区域中创建存储桶。

  • 在与工作负载的 VM(s) 相同的项目中创建存储桶。确保已设置权限,以便 VM(s) 可以写入存储桶。

  • 对于较小的工作负载,无需复制。较大的工作负载可能会从复制中受益。

  • 对存储桶使用“标准”作为默认存储类别。

  • 将软删除策略设置为最短:7 天。

  • 将对象生命周期设置为预期工作负载运行时间。例如,如果预期工作负载运行 10 天,则将对象生命周期设置为 10 天。这应该涵盖在整个运行期间发生的重启。对生命周期条件使用 age,对操作使用 Delete。有关详细信息,请参阅 对象生命周期管理。如果未设置对象生命周期,则缓存将继续增长,因为没有实现驱逐机制。

  • 支持所有加密策略。

假设 gs://jax-cache 是 GCS 存储桶,请按如下方式设置缓存位置

jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")

工作原理#

缓存键是包含以下参数的已编译函数的签名

  • 由被散列的 JAX 函数的非优化 HLO 捕获的函数执行的计算

  • jaxlib 版本

  • 相关的 XLA 编译标志

  • 设备配置通常由设备数量和设备拓扑捕获。当前,对于 GPU,拓扑仅包含 GPU 名称的字符串表示形式

  • 用于压缩已编译可执行文件的压缩算法

  • jax._src.cache_key.custom_hook() 生成的字符串。此函数可以重新分配给用户定义的函数,以便可以更改生成的字符串。默认情况下,此函数始终返回空字符串。

在多个节点上缓存#

首次运行程序(持久缓存为空/为空)时,所有进程都将进行编译,但只有全局通信组中排名为 0 的进程会写入持久缓存。在随后的运行中,所有进程都将尝试从持久缓存中读取,因此持久缓存必须位于共享文件系统(例如:NFS)或远程存储(例如:GFS)中。如果持久缓存是排名 0 的本地缓存,则排名 0 以外的所有进程将再次在随后的运行中进行编译,因为编译缓存未命中。

记录缓存活动#

检查持久编译缓存到底发生了什么可能有助于调试。以下是一些关于如何开始的建议。

用户可以通过将

import os
os.environ["JAX_DEBUG_LOG_MODULES"] = "jax._src.compiler,jax._src.lru_cache"

放在脚本顶部来启用相关源文件的日志记录。

检查缓存未命中#

为了检查和理解为什么存在缓存未命中,JAX 包含一个配置标志,该标志启用所有缓存未命中(包括持久编译缓存未命中)及其解释的日志记录。虽然目前,这仅针对跟踪缓存未命中实现,但最终目标是解释所有缓存未命中。可以通过设置以下配置来启用它。

jax.config.update("jax_explain_cache_misses", True)

陷阱#

目前已发现一些陷阱

  • 目前,持久缓存不适用于具有主机回调的函数。在这种情况下,完全避免缓存。

    • 这是因为 HLO 包含指向回调的指针,并且即使计算和计算基础设施完全相同,它也会在每次运行时发生变化。

  • 目前,持久缓存不适用于使用实现其自己的 custom_partitioning 的原语的函数。

    • 函数的 HLO 包含指向 custom_partitioning 回调的指针,并且会导致相同计算在每次运行时产生不同的缓存键。

    • 在这种情况下,缓存仍然会继续,但每次都会生成不同的键,这使得缓存无效。

解决 custom_partitioning#

如前所述,编译缓存不适用于由实现 custom_partitioning 的原语组成的函数。但是,可以使用 shard_map 来规避那些实现它的原语的 custom_partitioning,并使编译缓存按预期工作

假设我们有一个函数 F,它实现一个层归一化,然后使用实现 custom_partitioning 的原语 LayerNorm 进行矩阵乘法

import jax

def F(x1, x2, gamma, beta):
   ln_out = LayerNorm(x1, gamma, beta)
   return ln_out @ x2

如果我们只是编译这个函数而不使用 shard_map,那么 layernorm_matmul_without_shard_map 的缓存键将在每次运行相同代码时都不同

layernorm_matmul_without_shard_map = jax.jit(F, in_shardings=(...), out_sharding=(...))(x1, x2, gamma, beta)

但是,如果我们在 shard_map 中包装层归一化原语并定义一个执行相同计算的函数 G,那么 layernorm_matmul_with_shard_map 的缓存键将始终相同,即使 LayerNorm 实现了 custom_partitioning

import jax
from jax.experimental.shard_map import shard_map

def G(x1, x2, gamma, beta, mesh, ispecs, ospecs):
   ln_out = shard_map(LayerNorm, mesh, in_specs=ispecs, out_specs=ospecs, check_rep=False)(x1, x2, gamma, beta)
   return ln_out @ x2

ispecs = jax.sharding.PartitionSpec(...)
ospecs = jax.sharding.PartitionSpec(...)
mesh = jax.sharding.Mesh(...)
layernorm_matmul_with_shard_map = jax.jit(G, static_argnames=['mesh', 'ispecs', 'ospecs'])(x1, x2, gamma, beta, mesh, ispecs, ospecs)

请注意,必须将实现 custom_partitioning 的原语包装在 shard_map 中,才能解决此问题。仅将外部函数 F 包装在 shard_map 中是不够的。