GPU 内存分配#
当第一次运行 JAX 操作时,JAX 会预先分配总 GPU 内存的 75%。预分配可以最大程度地减少分配开销和内存碎片,但有时会导致内存不足 (OOM) 错误。如果您的 JAX 进程因 OOM 失败,可以使用以下环境变量来覆盖默认行为
XLA_PYTHON_CLIENT_PREALLOCATE=false
这会禁用预分配行为。JAX 将改为根据需要分配 GPU 内存,这可能会减少整体内存使用量。但是,这种行为更容易导致 GPU 内存碎片,这意味着使用大部分可用 GPU 内存的 JAX 程序在禁用预分配后可能会出现 OOM。
XLA_PYTHON_CLIENT_MEM_FRACTION=.XX
如果启用了预分配,这将使 JAX 预先分配总 GPU 内存的 XX%,而不是默认的 75%。降低预分配的量可以解决 JAX 程序启动时发生的 OOM 问题。
XLA_PYTHON_CLIENT_ALLOCATOR=platform
这使得 JAX 根据需要精确分配所需的内存,并释放不再需要的内存(请注意,这是唯一会释放 GPU 内存而不是重复使用它的配置)。这非常慢,因此不建议一般使用,但可能有助于以最小的 GPU 内存占用量运行或调试 OOM 故障。
OOM 故障的常见原因#
- 同时运行多个 JAX 进程。
或者使用
XLA_PYTHON_CLIENT_MEM_FRACTION
为每个进程提供适当数量的内存,或者设置XLA_PYTHON_CLIENT_PREALLOCATE=false
。- 同时运行 JAX 和 GPU TensorFlow。
TensorFlow 默认情况下也会预分配,因此这类似于同时运行多个 JAX 进程。
一种解决方案是使用仅 CPU 的 TensorFlow(例如,如果您只使用 TF 进行数据加载)。您可以使用命令
tf.config.experimental.set_visible_devices([], "GPU")
防止 TensorFlow 使用 GPU或者,使用
XLA_PYTHON_CLIENT_MEM_FRACTION
或XLA_PYTHON_CLIENT_PREALLOCATE
。 还有类似的选项可以配置 TensorFlow 的 GPU 内存分配(TF1 中的 gpu_memory_fraction 和 allow_growth,应该在传递给tf.Session
的tf.ConfigProto
中设置。有关 TF2,请参阅 使用 GPU:限制 GPU 内存增长)。- 在显示 GPU 上运行 JAX。
使用
XLA_PYTHON_CLIENT_MEM_FRACTION
或XLA_PYTHON_CLIENT_PREALLOCATE
。