分析设备内存#
注意
2023 年 5 月更新:我们建议使用 Tensorboard 分析 进行设备内存分析。在进行分析后,打开 Tensorboard 分析器的 memory_viewer
选项卡,以了解更详细和更易理解的设备内存使用情况。
JAX 设备内存分析器允许我们探索 JAX 程序如何以及为何使用 GPU 或 TPU 内存。例如,它可以用于
找出在给定时间哪些数组和可执行文件在 GPU 内存中,或者
追踪内存泄漏。
安装#
JAX 设备内存分析器会发出可以使用 pprof (google/pprof) 解释的输出。首先安装 pprof
,请按照其安装说明进行操作。在编写本文时,安装 pprof
需要先安装 1.16+ 版本的 Go 和 Graphviz,然后运行
go install github.com/google/pprof@latest
这将 pprof
安装为 $GOPATH/bin/pprof
,其中 GOPATH
默认为 ~/go
。
注意
来自 google/pprof 的 pprof
版本与作为 gperftools
包一部分分发的同名旧工具不同。 gperftools
版本的 pprof
将不适用于 JAX。
了解 JAX 程序如何使用 GPU 或 TPU 内存#
设备内存分析器的一个常见用途是找出 JAX 程序为何使用大量 GPU 或 TPU 内存,例如在尝试调试内存不足问题时。
要将设备内存配置文件捕获到磁盘,请使用 jax.profiler.save_device_memory_profile()
。例如,考虑以下 Python 程序
import jax
import jax.numpy as jnp
import jax.profiler
def func1(x):
return jnp.tile(x, 10) * 0.5
def func2(x):
y = func1(x)
return y, jnp.tile(x, 10) + 1
x = jax.random.normal(jax.random.key(42), (1000, 1000))
y, z = func2(x)
z.block_until_ready()
jax.profiler.save_device_memory_profile("memory.prof")
如果我们首先运行上面的程序,然后执行
pprof --web memory.prof
pprof
会打开一个 Web 浏览器,其中包含设备内存配置文件的以下调用图格式的可视化
调用图是对每个活动的缓冲区分配时 Python 堆栈的可视化。例如,在这种特定情况下,可视化显示 func2
及其被调用者负责分配 76.30MB,其中 38.15MB 是在从 func1
到 func2
的调用内部分配的。 有关如何解释调用图可视化的更多信息,请参阅 pprof 文档。
使用 jax.jit()
编译的函数对于设备内存分析器是不透明的。也就是说,在 jit
编译的函数内部分配的任何内存都将归因于整个函数。
在示例中,对 block_until_ready()
的调用是为了确保 func2
在收集设备内存配置文件之前完成。有关更多详细信息,请参阅 异步调度。
调试内存泄漏#
我们还可以使用 JAX 设备内存分析器来跟踪内存泄漏,方法是使用 pprof
可视化在不同时间拍摄的两个设备内存配置文件之间的内存使用变化。例如,考虑以下程序,该程序将 JAX 数组累积到一个不断增长的 Python 列表中。
import jax
import jax.numpy as jnp
import jax.profiler
def afunction():
return jax.random.normal(jax.random.key(77), (1000000,))
z = afunction()
def anotherfunc():
arrays = []
for i in range(1, 10):
x = jax.random.normal(jax.random.key(42), (i, 10000))
arrays.append(x)
x.block_until_ready()
jax.profiler.save_device_memory_profile(f"memory{i}.prof")
anotherfunc()
如果我们仅可视化执行结束时的设备内存配置文件 (memory9.prof
),则可能不明显 anotherfunc
中循环的每次迭代都会累积更多的设备内存分配
pprof --web memory9.prof
afunction
内部的大但固定的分配主导着配置文件,但不会随着时间的推移而增长。
通过使用 pprof
的 --diff_base
功能 可视化跨循环迭代的内存使用变化,我们可以确定程序内存使用量随时间增加的原因
pprof --web --diff_base memory1.prof memory9.prof
可视化显示内存增长可归因于 anotherfunc
内部对 normal
的调用。