多主机和多进程环境#
简介#
本指南介绍如何在 GPU 集群和 Cloud TPU Pod 等环境中,在多个 CPU 主机或 JAX 进程之间分布加速器时使用 JAX。 我们将这些称为“多进程”环境。
本指南专门介绍如何在多进程环境中使用集体通信操作(例如 jax.lax.psum()
),尽管其他通信方法也可能在您的用例中很有用(例如,RPC、mpi4jax)。如果您还不熟悉 JAX 的集体操作,我们建议从并行编程简介部分开始。 JAX 中多进程环境的一个重要要求是加速器之间的直接通信链接,例如 Cloud TPU 的高速互连或 GPU 的 NCCL。这些链接允许集体操作跨多个进程的加速器以高性能运行。
多进程编程模型#
核心概念
每个主机必须至少运行一个 JAX 进程。
您应该使用
jax.distributed.initialize()
初始化集群。每个进程都有其可以寻址的不同的本地设备集。全局设备是所有进程中所有设备的集合。
使用标准的 JAX 并行 API,例如
jit()
(请参阅并行编程简介教程)和shard_map()
。 jax.jit 仅接受全局形状的数组。 shard_map 允许您降到每个设备的形状。确保所有进程以相同的顺序运行相同的并行计算。
确保所有进程都具有相同数量的本地设备。
确保所有设备都相同(例如,所有 V100 或所有 H100)。
启动 JAX 进程#
与其他由单个控制器节点管理多个工作节点的分布式系统不同,JAX 使用“多控制器”编程模型,其中每个 JAX Python 进程独立运行,有时被称为单程序多数据 (SPMD)模型。 通常,相同的 JAX Python 程序在每个进程中运行,每个进程的执行之间只有细微的差异(例如,不同的进程将加载不同的输入数据)。 此外,您必须手动在每个主机上运行您的 JAX 程序! JAX 不会自动从单个程序调用启动多个进程。
(需要多个进程是为什么本指南不以笔记本的形式提供的原因 – 我们目前没有好的方法从单个笔记本管理多个 Python 进程。)
初始化集群#
要初始化集群,您应该在每个进程开始时调用 jax.distributed.initialize()
。jax.distributed.initialize()
必须在程序早期调用,在执行任何 JAX 计算之前。
API jax.distributed.initialize()
接受几个参数,即
coordinator_address
:集群中进程 0 的 IP 地址,以及该进程上可用的端口。进程 0 将启动一个通过该 IP 地址和端口公开的 JAX 服务,集群中的其他进程将连接到该服务。coordinator_bind_address
:JAX 服务在集群中进程 0 上绑定的 IP 地址和端口。默认情况下,它将使用与coordinator_address
相同的端口绑定到所有可用接口。num_processes
:集群中的进程数process_id
:此进程的 ID 号,范围为[0 .. num_processes)
。local_device_ids
:将当前进程的可见设备限制为local_device_ids
。
例如,在 GPU 上,典型用法是
import jax
jax.distributed.initialize(coordinator_address="192.168.0.1:1234",
num_processes=2,
process_id=0)
在 Cloud TPU、Slurm 和 Open MPI 环境中,您可以直接调用 jax.distributed.initialize()
,而无需任何参数。 将自动选择参数的默认值。 在使用 Slurm 和 Open MPI 在 GPU 上运行时,假定每个 GPU 启动一个进程,即每个进程将仅分配一个可见的本地设备。 否则,假定每个主机启动一个进程,即每个进程将被分配所有本地设备。 仅当 JAX 进程通过 mpirun
/mpiexec
启动时,才使用 Open MPI 自动初始化。
import jax
jax.distributed.initialize()
目前,在 TPU 上调用 jax.distributed.initialize()
是可选的,但建议使用,因为它启用了额外的检查点和健康检查功能。
本地设备与全局设备#
在我们开始从您的程序运行多进程计算之前,了解本地设备和全局设备之间的区别非常重要。
一个进程的本地设备是它可以直接寻址并在其上启动计算的设备。例如,在 GPU 集群上,每个主机只能在其直接连接的 GPU 上启动计算。在 Cloud TPU Pod 上,每个主机只能在其直接连接的 8 个 TPU 核心上启动计算(有关更多详细信息,请参阅 Cloud TPU 系统架构文档)。您可以通过 jax.local_devices()
查看进程的本地设备。
全局设备是所有进程中的设备。计算可以跨进程跨设备,并通过设备之间的直接通信链路执行集体操作,只要每个进程都在其本地设备上启动计算即可。您可以通过 jax.devices()
查看所有可用的全局设备。进程的本地设备始终是全局设备的子集。
运行多进程计算#
那么,您实际上如何运行涉及跨进程通信的计算呢?使用您在单个进程中使用的相同并行评估 API!
例如,shard_map()
可用于跨多个进程运行并行计算。(如果您还不熟悉如何使用 shard_map
在单个进程中的多个设备上运行,请查看并行编程简介教程。)从概念上讲,这可以被认为是跨主机分片的单个数组运行 pmap,其中每个主机“看到”的只是其输入的本地分片和输出。
这是一个多进程 pmap 的实际示例
# The following is run in parallel on each host on a GPU cluster or TPU pod slice.
>>> import jax
>>> jax.distributed.initialize() # On GPU, see above for the necessary arguments.
>>> jax.device_count() # total number of accelerator devices in the cluster
32
>>> jax.local_device_count() # number of accelerator devices attached to this host
8
# The psum is performed over all mapped devices across the pod slice
>>> xs = jax.numpy.ones(jax.local_device_count())
>>> jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32)
所有进程都以相同的顺序运行相同的跨进程计算非常重要。在每个进程中运行相同的 JAX Python 程序通常就足够了。 需要注意一些可能导致计算顺序不同的常见陷阱,尽管它们运行的是相同的程序。
进程将形状不同的输入传递给同一个并行函数可能会导致挂起或返回错误的值。 只要它们导致跨进程的每个设备的数据分片形状相同,则形状不同的输入是安全的;例如,为了在每个进程的不同数量的本地设备上运行,传入不同的前导批大小是没问题的,但是让每个进程将其批次填充到不同的最大示例长度则不行。
“最后批次”问题,即在(训练)循环中调用并行函数,并且一个或多个进程比其余进程更早退出循环。 这将导致其余进程挂起,等待已完成的进程开始计算。
基于集合的非确定性排序的条件可能会导致代码进程挂起。 例如,在当前 Python 版本上迭代
set
或在 Python 3.7 之前迭代dict
,即使插入顺序相同,也可能导致不同进程上的排序不同。