多主机和多进程环境#

简介#

本指南介绍了如何在诸如 GPU 集群和 Cloud TPU Pod 等环境中,使用 JAX,这些环境中的加速器分布在多个 CPU 主机或 JAX 进程之间。我们将这些称为“多进程”环境。

本指南特别关注如何在多进程设置中使用集体通信操作(例如 jax.lax.psum()),尽管其他通信方法也可能在你的用例中很有用(例如 RPC、mpi4jax)。如果你还不熟悉 JAX 的集体操作,我们建议从并行编程入门部分开始。JAX 中多进程环境的一个重要要求是加速器之间的直接通信链路,例如 Cloud TPU 的高速互连或 GPU 的 NCCL。这些链接允许集体操作在多个进程的加速器上以高性能运行。

多进程编程模型#

核心概念

  • 每个主机必须至少运行一个 JAX 进程。

  • 您应该使用 jax.distributed.initialize() 初始化集群。

  • 每个进程都有一组不同的本地设备,可以寻址。全局设备是所有进程中所有设备的集合。

  • 使用标准的 JAX 并行 API,例如 jit()(请参阅 并行编程简介 教程)和 shard_map()。 jax.jit 仅接受全局形状的数组。shard_map 允许您降至每个设备的形状。

  • 确保所有进程以相同的顺序运行相同的并行计算。

  • 确保所有进程具有相同数量的本地设备。

  • 确保所有设备都相同(例如,所有 V100 或所有 H100)。

启动 JAX 进程#

与其他由单个控制器节点管理多个工作节点的分布式系统不同,JAX 使用“多控制器”编程模型,其中每个 JAX Python 进程独立运行,有时称为 单程序多数据 (SPMD) 模型。通常,相同的 JAX Python 程序在每个进程中运行,每个进程的执行只有细微的差异(例如,不同的进程将加载不同的输入数据)。此外,您必须手动在每个主机上运行您的 JAX 程序! JAX 不会自动从单个程序调用启动多个进程。

(需要多个进程是为什么本指南不作为笔记本提供的的原因 – 我们目前没有一个很好的方法从单个笔记本管理多个 Python 进程。)

初始化集群#

要初始化集群,您应该在每个进程的开始调用 jax.distributed.initialize()jax.distributed.initialize() 必须在程序早期调用,在执行任何 JAX 计算之前。

API jax.distributed.initialize() 接受几个参数,即

  • coordinator_address:集群中进程 0 的 IP 地址,以及该进程上可用的端口。进程 0 将启动通过该 IP 地址和端口公开的 JAX 服务,集群中的其他进程将连接到该服务。

  • coordinator_bind_address:集群中进程 0 上的 JAX 服务将绑定到的 IP 地址和端口。默认情况下,它将使用与 coordinator_address 相同的端口绑定到所有可用的接口。

  • num_processes:集群中的进程数

  • process_id:此进程的 ID 号,范围为 [0 .. num_processes)

  • local_device_ids:将当前进程的可见设备限制为 local_device_ids

例如,在 GPU 上,典型的用法是

import jax

jax.distributed.initialize(coordinator_address="192.168.0.1:1234",
                           num_processes=2,
                           process_id=0)

在 Cloud TPU、Slurm 和 Open MPI 环境中,您可以简单地调用 jax.distributed.initialize(),无需任何参数。将自动选择参数的默认值。当在 Slurm 和 Open MPI 上使用 GPU 运行时,假设每个 GPU 启动一个进程,即每个进程只分配一个可见的本地设备。否则,假设每个主机启动一个进程,即每个进程分配所有本地设备。仅当 JAX 进程通过 mpirun/mpiexec 启动时才使用 Open MPI 自动初始化。

import jax

jax.distributed.initialize()

目前,在 TPU 上调用 jax.distributed.initialize() 是可选的,但建议这样做,因为它启用了额外的检查点和健康检查功能。

本地设备与全局设备#

在我们从您的程序中运行多进程计算之前,理解本地设备和全局设备之间的区别非常重要。

进程的本地设备是它可以直接寻址并在其上启动计算的设备。例如,在 GPU 集群上,每个主机只能在直接连接的 GPU 上启动计算。在 Cloud TPU Pod 上,每个主机只能在直接连接到该主机的 8 个 TPU 核心上启动计算(有关更多详细信息,请参阅 Cloud TPU 系统架构 文档)。您可以通过 jax.local_devices() 查看进程的本地设备。

全局设备是所有进程中的设备。只要每个进程在其本地设备上启动计算,计算就可以跨进程的设备,并通过设备之间的直接通信链路执行集体操作。您可以通过 jax.devices() 查看所有可用的全局设备。进程的本地设备始终是全局设备的一个子集。

运行多进程计算#

那么,如何实际运行涉及跨进程通信的计算呢?使用与单个进程中相同的并行评估 API!

例如,shard_map() 可用于在多个进程之间运行并行计算。(如果您还不熟悉如何使用 shard_map 在单个进程中的多个设备上运行,请查看 并行编程简介 教程。)从概念上讲,这可以被认为是运行在跨主机分片的单个数组上的 pmap,其中每个主机“只看到”其本地的输入和输出分片。

这是一个正在运行的多进程 pmap 的示例

# The following is run in parallel on each host on a GPU cluster or TPU pod slice.
>>> import jax
>>> jax.distributed.initialize()  # On GPU, see above for the necessary arguments.
>>> jax.device_count()  # total number of accelerator devices in the cluster
32
>>> jax.local_device_count()  # number of accelerator devices attached to this host
8
# The psum is performed over all mapped devices across the pod slice
>>> xs = jax.numpy.ones(jax.local_device_count())
>>> jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32)

所有进程以相同顺序运行相同的跨进程计算非常重要。通常,在每个进程中运行相同的 JAX Python 程序就足够了。以下是一些需要注意的常见陷阱,这些陷阱可能会导致尽管运行相同的程序,计算顺序也不同

  • 进程将不同形状的输入传递给相同的并行函数可能会导致挂起或返回错误的值。只要它们在进程之间产生形状相同的每个设备数据分片,不同形状的输入是安全的;例如,传递不同的前导批量大小以便在每个进程的不同数量的本地设备上运行是可以的,但是让每个进程将其批量填充到不同的最大示例长度是不可以的。

  • “最后一个批量”问题,其中在(训练)循环中调用并行函数,并且一个或多个进程比其他进程更早退出循环。这将导致其余的进程挂起,等待已完成的进程开始计算。

  • 基于集合的非确定性排序的条件会导致代码进程挂起。例如,在当前 Python 版本上迭代 setdict 在 Python 3.7 之前 可能会导致不同进程上的顺序不同,即使插入顺序相同。