jax.distributed.initialize#
- jax.distributed.initialize(coordinator_address=None, num_processes=None, process_id=None, local_device_ids=None, cluster_detection_method=None, initialization_timeout=300, coordinator_bind_address=None)[源代码]#
初始化 JAX 分布式系统。
调用
initialize()
为在多主机 GPU 和 Cloud TPU 上执行 JAX 做准备。必须在执行任何 JAX 计算之前调用initialize()
。JAX 分布式系统承担多个角色
它允许 JAX 进程互相发现并共享拓扑信息,
它执行健康检查,确保如果任何进程终止,所有进程都关闭,以及
它用于分布式检查点。
如果您正在使用 TPU、Slurm 或 Open MPI,则所有参数都是可选的:如果省略,它们将自动选择。
可以使用
cluster_detection_method
来选择检测这些分布式参数的特定方法。您可以将任何自动spec_detect_methods
传递给此参数,尽管在 TPU、Slurm 或 Open MPI 的情况下没有必要。对于其他 MPI 安装,如果您安装了功能正常的mpi4py
,则可以传递cluster_detection_method="mpi4py"
来引导所需的参数。否则,您必须将
coordinator_address
、num_processes
、process_id
和local_device_ids
参数提供给initialize()
。当提供所有四个参数时,将跳过集群环境自动检测。请注意:在某些系统上,特别是只能通过代理变量(例如 HTTP_PROXY、HTTPS_PROXY 等)访问外部网络的 HPC 集群上,调用
initialize()
可能会超时。您可能需要在应用程序启动之前取消设置这些变量。- 参数:
coordinator_address (
str
|None
|None
) – 进程0
的 IP 地址以及该进程应启动协调器服务的端口。端口的选择并不重要,只要该端口在协调器上可用,并且所有进程都同意该端口。仅在受支持的环境中可以为None
,在这种情况下它将自动选择。请注意,像localhost
或127.0.0.1
这样的特殊地址通常意味着程序将绑定到本地接口,并且不适合在多主机环境中运行。num_processes (
int
|None
|None
) – 进程数。仅在受支持的环境中可以为None
,在这种情况下它将自动选择。process_id (
int
|None
|None
) – 当前进程的 ID 号。集群中的process_id
值必须是一个密集的范围0
,1
, …,num_processes - 1
。仅在受支持的环境中可以为None
;如果为None
,它将自动选择。local_device_ids (
int
|Sequence
[int
] |None
|None
) – 将当前进程的可见设备限制为local_device_ids
。如果为None
,则默认为该进程可见的所有本地设备,除非通过 Slurm 和 Open MPI 在 GPU 上启动进程。在这种情况下,它将默认为每个进程一个设备。cluster_detection_method (
str
|None
|None
) – 一个可选的字符串,用于尝试自动检测分布式运行的配置。请注意,“mpi4py” 方法要求您在环境中安装可用的mpi4py
,并使用与 MPI 兼容的作业启动器(例如mpiexec
或mpirun
)启动应用程序。仍启用旧的自动检测选项 “ompi” (OMPI) 和 “slurm” (Slurm)。“deactivate” 会绕过自动集群检测。initialization_timeout (
int
) – 重试连接的时间段(以秒为单位)。如果初始化花费的时间超过指定的超时时间,则初始化将出错。默认为 300 秒,即 5 分钟。coordinator_bind_address (
str
|None
|None
) – 进程0
上的协调器服务应绑定的地址和端口。如果未指定,则默认绑定到与coordinator_address
相同的端口上的所有可用地址。在每个节点有多个网络接口的系统上,仅让协调器服务侦听一个地址/接口可能是不够的。
- 引发:
RuntimeError – 如果
initialize()
被多次调用,或者在后端已经初始化后被调用。
示例
假设有两个 GPU 进程,进程 0 是指定的协调器,地址为
10.0.0.1:1234
。要初始化 GPU 集群,请在其他任何操作之前运行以下命令。在进程 0 上
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0)
在进程 1 上
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1)