GPU 内存分配

GPU 内存分配#

当第一次运行 JAX 操作时,JAX 会预先分配总 GPU 内存的 75%。预分配可以最大程度地减少分配开销和内存碎片,但有时会导致内存不足 (OOM) 错误。如果您的 JAX 进程因 OOM 失败,可以使用以下环境变量来覆盖默认行为

XLA_PYTHON_CLIENT_PREALLOCATE=false

这会禁用预分配行为。JAX 将改为根据需要分配 GPU 内存,这可能会减少整体内存使用量。但是,这种行为更容易导致 GPU 内存碎片,这意味着使用大部分可用 GPU 内存的 JAX 程序在禁用预分配后可能会出现 OOM。

XLA_PYTHON_CLIENT_MEM_FRACTION=.XX

如果启用了预分配,这将使 JAX 预先分配总 GPU 内存的 XX%,而不是默认的 75%。降低预分配的量可以解决 JAX 程序启动时发生的 OOM 问题。

XLA_PYTHON_CLIENT_ALLOCATOR=platform

这使得 JAX 根据需要精确分配所需的内存,并释放不再需要的内存(请注意,这是唯一会释放 GPU 内存而不是重复使用它的配置)。这非常慢,因此不建议一般使用,但可能有助于以最小的 GPU 内存占用量运行或调试 OOM 故障。

OOM 故障的常见原因#

同时运行多个 JAX 进程。

或者使用 XLA_PYTHON_CLIENT_MEM_FRACTION 为每个进程提供适当数量的内存,或者设置 XLA_PYTHON_CLIENT_PREALLOCATE=false

同时运行 JAX 和 GPU TensorFlow。

TensorFlow 默认情况下也会预分配,因此这类似于同时运行多个 JAX 进程。

一种解决方案是使用仅 CPU 的 TensorFlow(例如,如果您只使用 TF 进行数据加载)。您可以使用命令 tf.config.experimental.set_visible_devices([], "GPU") 防止 TensorFlow 使用 GPU

或者,使用 XLA_PYTHON_CLIENT_MEM_FRACTIONXLA_PYTHON_CLIENT_PREALLOCATE。 还有类似的选项可以配置 TensorFlow 的 GPU 内存分配(TF1 中的 gpu_memory_fractionallow_growth,应该在传递给 tf.Sessiontf.ConfigProto 中设置。有关 TF2,请参阅 使用 GPU:限制 GPU 内存增长)。

在显示 GPU 上运行 JAX。

使用 XLA_PYTHON_CLIENT_MEM_FRACTIONXLA_PYTHON_CLIENT_PREALLOCATE