jax.profiler.device_memory_profile

jax.profiler.device_memory_profile#

jax.profiler.device_memory_profile(backend=None)[source]#

将 JAX 设备内存配置文件捕获为 pprof 格式的协议缓冲区。

设备内存配置文件是内存状态的快照,它描述了内存中存在的 JAX Array 和可执行对象,以及它们的分配位置。

有关如何使用设备内存分析器的更多信息,请参阅 分析设备内存

分析系统通过检测 JAX 的设备上分配来工作,为每个分配捕获一个 Python 堆栈跟踪。该检测始终启用;device_memory_profile() 提供了一个用于捕获它的 API。

输出 device_memory_profile() 是一个二进制协议缓冲区,可以使用 pprof 工具 进行解释和可视化。

参数:

backend (str | None | None) – 可选;要收集设备内存配置文件的 JAX 后端的名称。

返回:

包含二进制 pprof 格式协议缓冲区的字节字符串。

返回类型:

bytes