多主机和多进程环境#
简介#
本指南介绍如何在 GPU 集群和 Cloud TPU Pod 等环境中使用 JAX,这些环境中的加速器分布在多个 CPU 主机或 JAX 进程上。我们将把这些称为“多进程”环境。
本指南特别关注如何在多进程设置中使用集体通信操作(例如 jax.lax.psum()
),尽管根据您的用例,其他通信方法也可能有用(例如 RPC、mpi4jax)。如果您不熟悉 JAX 的集体操作,我们建议从 并行编程简介 部分开始。JAX 中多进程环境的一个重要要求是在加速器之间建立直接通信链路,例如 Cloud TPU 的高速互连或 GPU 的 NCCL。这些链路允许集体操作跨越多个进程的加速器以高性能运行。
多进程编程模型#
关键概念
您必须每个主机至少运行一个 JAX 进程。
您应该使用
jax.distributed.initialize()
初始化集群。每个进程都有一个独特的本地设备集,它可以寻址。全局设备是所有进程中所有设备的集合。
使用标准的 JAX 并行 API,例如
jit()
(参见 并行编程入门 教程)和shard_map()
。jax.jit 仅接受全局形状的数组。shard_map 允许您降级到每个设备的形状。确保所有进程以相同的顺序运行相同的并行计算。
确保所有进程具有相同数量的本地设备。
确保所有设备相同(例如,全部为 V100 或全部为 H100)。
启动 JAX 进程#
与其他分布式系统(其中单个控制器节点管理许多工作节点)不同,JAX 使用“多控制器”编程模型,其中每个 JAX Python 进程独立运行,有时称为 单程序多数据 (SPMD) 模型。通常,相同的 JAX Python 程序在每个进程中运行,每个进程的执行之间只有细微差别(例如,不同的进程将加载不同的输入数据)。此外,**您必须手动在每个主机上运行 JAX 程序!** JAX 不会自动从单个程序调用启动多个进程。
(需要多个进程的原因是本指南未作为笔记本提供——我们目前没有好的方法来从单个笔记本管理多个 Python 进程。)
初始化集群#
要初始化集群,您应该在每个进程的开始调用 jax.distributed.initialize()
。 jax.distributed.initialize()
必须在程序的早期调用,在执行任何 JAX 计算之前。
API jax.distributed.initialize()
接受多个参数,即
coordinator_address
:集群中进程 0 的 IP 地址,以及该进程上可用的端口。进程 0 将启动一个通过该 IP 地址和端口公开的 JAX 服务,集群中的其他进程将连接到该服务。coordinator_bind_address
:集群中进程 0 上的 JAX 服务将绑定到的 IP 地址和端口。默认情况下,它将使用与coordinator_address
相同的端口绑定到所有可用接口。num_processes
:集群中的进程数process_id
:此进程的 ID 号,范围为[0 .. num_processes)
。local_device_ids
:将当前进程的可见设备限制为local_device_ids
。
例如,在 GPU 上,典型用法是
import jax
jax.distributed.initialize(coordinator_address="192.168.0.1:1234",
num_processes=2,
process_id=0)
在 Cloud TPU、Slurm 和 Open MPI 环境中,您可以简单地调用 jax.distributed.initialize()
而不带任何参数。将自动选择参数的默认值。在使用 Slurm 和 Open MPI 的 GPU 上运行时,假设每个 GPU 启动一个进程,即每个进程将只分配一个可见的本地设备。否则,假设每个主机启动一个进程,即每个进程将分配所有本地设备。仅当通过 mpirun
/mpiexec
启动 JAX 进程时,才会使用 Open MPI 自动初始化。
import jax
jax.distributed.initialize()
目前,在 TPU 上调用 jax.distributed.initialize()
是可选的,但建议这样做,因为它启用了其他检查点和健康检查功能。
本地设备与全局设备#
在我们开始从您的程序运行多进程计算之前,了解本地设备和全局设备之间的区别非常重要。
**进程的本地设备是它可以直接寻址并在其上启动计算的设备。**例如,在 GPU 集群中,每个主机只能在其直接连接的 GPU 上启动计算。在 Cloud TPU Pod 中,每个主机只能在其直接连接的 8 个 TPU 核心上启动计算(有关更多详细信息,请参阅 Cloud TPU 系统架构 文档)。您可以通过 jax.local_devices()
查看进程的本地设备。
**全局设备是所有进程中的设备。**计算可以跨进程跨设备扩展并执行集体操作,方法是通过设备之间的直接通信链路,只要每个进程在其本地设备上启动计算即可。您可以通过 jax.devices()
查看所有可用的全局设备。进程的本地设备始终是全局设备的子集。
运行多进程计算#
那么您如何实际运行涉及跨进程通信的计算呢?**使用与在单个进程中相同的并行评估 API!**
例如,shard_map()
可用于在多个进程中运行并行计算。(如果您还不熟悉如何使用 shard_map
在单个进程中的多个设备上运行,请查看 并行编程入门 教程。)从概念上讲,这可以被认为是在跨主机分片的单个数组上运行 pmap,其中每个主机“仅”看到其本地输入和输出分片。
这是一个多进程 pmap 运行的示例
# The following is run in parallel on each host on a GPU cluster or TPU pod slice.
>>> import jax
>>> jax.distributed.initialize() # On GPU, see above for the necessary arguments.
>>> jax.device_count() # total number of accelerator devices in the cluster
32
>>> jax.local_device_count() # number of accelerator devices attached to this host
8
# The psum is performed over all mapped devices across the pod slice
>>> xs = jax.numpy.ones(jax.local_device_count())
>>> jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32)
**所有进程以相同的顺序运行相同的跨进程计算非常重要。**在每个进程中运行相同的 JAX Python 程序通常就足够了。一些常见的可能导致计算顺序不同(即使运行相同的程序)的陷阱
进程将不同形状的输入传递给相同的并行函数会导致挂起或返回值不正确。只要不同形状的输入导致进程之间相同形状的每个设备数据分片,就可以安全地使用;例如,为了在每个进程的不同数量的本地设备上运行,传递不同的前导批次大小是可以的,但让每个进程将其批次填充到不同的最大示例长度是不可以的。
“最后一批”问题,其中在(训练)循环中调用并行函数,并且一个或多个进程比其他进程更早退出循环。这将导致其余进程挂起,等待已完成的进程开始计算。
基于集合的非确定性排序的条件会导致代码进程挂起。例如,在当前 Python 版本上迭代
set
或dict
在 Python 3.7 之前 可能会导致不同进程上的排序不同,即使插入顺序相同。