GPU 内存分配#

当第一个 JAX 操作运行时,JAX 将预先分配 75% 的 GPU 总内存。 预分配可以最大限度地减少分配开销和内存碎片,但有时会导致内存不足 (OOM) 错误。 如果您的 JAX 进程因 OOM 而失败,可以使用以下环境变量来覆盖默认行为

XLA_PYTHON_CLIENT_PREALLOCATE=false

这将禁用预分配行为。 JAX 将改为根据需要分配 GPU 内存,从而可能减少总体内存使用量。但是,这种行为更容易发生 GPU 内存碎片,这意味着禁用预分配后,使用大部分可用 GPU 内存的 JAX 程序可能会出现 OOM。

XLA_PYTHON_CLIENT_MEM_FRACTION=.XX

如果启用预分配,这将使 JAX 预分配 XX% 的 GPU 总内存,而不是默认的 75%。 降低预分配的量可以修复 JAX 程序启动时发生的 OOM。

XLA_PYTHON_CLIENT_ALLOCATOR=platform

这会使 JAX 根据需求分配所需的内存,并释放不再需要的内存(请注意,这是唯一会释放 GPU 内存而不是重用的配置)。 这非常慢,因此不建议用于一般用途,但可能有助于以尽可能小的 GPU 内存占用或调试 OOM 故障来运行。

OOM 失败的常见原因#

同时运行多个 JAX 进程。

可以使用 XLA_PYTHON_CLIENT_MEM_FRACTION 为每个进程分配适当的内存量,或者设置 XLA_PYTHON_CLIENT_PREALLOCATE=false

同时运行 JAX 和 GPU TensorFlow。

TensorFlow 默认也会进行预分配,这与同时运行多个 JAX 进程类似。

一种解决方案是使用仅限 CPU 的 TensorFlow(例如,如果您仅使用 TF 进行数据加载)。您可以使用命令 tf.config.experimental.set_visible_devices([], "GPU") 来阻止 TensorFlow 使用 GPU

或者,可以使用 XLA_PYTHON_CLIENT_MEM_FRACTIONXLA_PYTHON_CLIENT_PREALLOCATE。 还有类似的选项可以配置 TensorFlow 的 GPU 内存分配(在 TF1 中,gpu_memory_fractionallow_growth 应该在传递给 tf.Sessiontf.ConfigProto 中设置。 有关 TF2 的信息,请参见 使用 GPU:限制 GPU 内存增长)。

在显示 GPU 上运行 JAX。

使用 XLA_PYTHON_CLIENT_MEM_FRACTIONXLA_PYTHON_CLIENT_PREALLOCATE

禁用重物化 HLO pass

有时,禁用自动重物化 HLO pass 有利于避免编译器做出较差的重物化选择。 可以通过设置 jax.config.update('enable_remat_opt_pass', True)jax.config.update('enable_remat_opt_pass', False) 来启用/禁用该 pass。 启用或禁用自动重物化 pass 会在计算和内存之间产生不同的权衡。 但是请注意,该算法是基础的,您通常可以通过禁用自动重物化 pass 并使用 jax.remat API 手动执行来获得更好的计算和内存之间的权衡。

实验性功能#

此处的特性是实验性的,必须谨慎尝试。

TF_GPU_ALLOCATOR=cuda_malloc_async

这会将 XLA 自己的 BFC 内存分配器替换为 cudaMallocAsync。 这将删除大的固定预分配,并使用一个增长的内存池。 预期的好处是不需要设置 XLA_PYTHON_CLIENT_MEM_FRACTION

风险在于

  • 内存碎片情况不同,因此如果您接近极限,由于碎片导致的 OOM 情况会有所不同。

  • 分配时间不会在开始时全部付出,而是在需要增加内存池时产生。 因此,您可能会在开始时体验到较少的速度稳定性,并且对于基准测试,忽略前几次迭代会更加重要。

可以通过预分配大量的内存块来缓解风险,并且仍然可以享受增长的内存池的好处。 这可以通过 TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC=N 完成。 如果 N 是 -1,则将预分配与默认分配相同的内存量。 否则,它是您要预分配的大小(以字节为单位)。