分布式数据加载#
本高级指南演示了如何在运行 JAX 时执行分布式数据加载 — 当您在 多主机或多进程环境 中运行 JAX,并且 JAX 计算所需的数据分布在多个进程中时。本文档涵盖了如何考虑分布式数据加载的整体方法,以及如何将其应用于数据并行(更简单)和模型并行(更复杂)工作负载。
与替代方案(例如:1) 在单个进程中加载完整全局数据,将其拆分并通过 RPC 将所需部分发送到其他进程;以及 2) 在所有进程中加载完整全局数据,并且每个进程仅使用所需部分)相比,分布式数据加载通常效率更高(数据分布在多个进程中),但也更复杂。加载完整全局数据通常更简单但成本更高。例如,在机器学习中,训练循环可能会在等待数据时被阻塞,并且每个进程都会使用额外的网络带宽。
注意
使用分布式数据加载时,重要的是每个设备(例如,每个 GPU 或 TPU)都能访问运行计算所需的输入数据分片。这通常使得分布式数据加载更复杂,并且难以正确实现(与上面描述的替代方案相比)。如果错误的数据分片最终位于错误的设备上,计算仍然可以运行而不会出错,因为计算无法知道输入数据“应该”是什么。但是,最终结果通常会不正确,因为输入数据与预期不同。
加载 jax.Array
的通用方法#
考虑创建一个单个jax.Array
的案例,该数组来自非 JAX 生成的原始数据。这些概念适用于加载批量数据记录之外的情况,例如任何非 JAX 计算直接产生的多进程jax.Array
。例如:1)从检查点加载模型权重;或 2)加载大型空间分片图像。
每个jax.Array
都有一个关联的Sharding
,它描述了每个全局设备需要全局数据的哪个分片。当您从头创建一个jax.Array
时,您也需要创建它的Sharding
。这是 JAX 了解数据如何在设备之间布局的方式。您可以创建任何您想要的Sharding
。在实践中,您通常会根据您正在实现的并行策略类型选择Sharding
(您将在本指南的后面更详细地了解数据和模型并行性)。您还可以根据每个进程中原始数据生成的方式选择Sharding
。
定义Sharding
后,您可以使用addressable_devices()
提供当前进程内加载数据所需的设备列表。(注意:“可寻址设备”是“本地设备”的更通用版本。目标是确保每个进程的数据加载器向该进程的所有本地设备提供正确的数据。
示例#
例如,考虑一个(64, 128)
jax.Array
,您需要将其跨 4 个进程(每个进程 2 个设备,总共 8 个设备)进行分片。这将导致 8 个唯一的数据分片,每个设备一个。分片此jax.Array
的方法有很多种。您可以沿着jax.Array
的第二维进行 1D 分片,为每个设备提供一个(64, 16)
的分片,如下所示
在上图中,每个数据分片都有自己的颜色,表示哪个进程需要加载该分片。例如,假设进程0
的 2 个设备包含分片A
和B
,对应于全局数据的第一个(64, 32)
部分。
您可以选择分片到设备的不同分布。例如
这是另一个例子——2D 分片
无论jax.Array
如何分片,您都必须确保每个进程的数据加载器都提供/加载全局数据的所需分片。实现这一点有几种高级方法:1)在每个进程中加载全局数据;2)使用每个设备的数据管道;3)使用合并的每个进程的数据管道;4)以某种方便的方式加载数据,然后在计算内部重新分片。
选项 1:在每个进程中加载全局数据#
使用此选项,每个进程
加载所需完整的值;以及
仅将所需的分片传输到该进程的本地设备。
这并不是一种有效的分布式数据加载方法,因为每个进程都会丢弃其本地设备不需要的数据,并且摄取的总数据量可能高于必要。但是此选项有效且相对易于实现,而性能开销对于某些工作负载可能是可以接受的(例如,如果全局数据很小)。
选项 2:使用每个设备的数据管道#
在此选项中,每个进程为其每个本地设备设置一个数据加载器(即,每个设备都获得自己的数据加载器,仅用于其所需的数据分片)。
这在加载的数据方面效率很高。有时考虑每个设备而不是所有进程的本地设备一次也可能更简单(请参阅下面的选项 3:使用合并的每个进程的数据管道)。但是,拥有多个并发数据加载器有时会导致性能问题。
选项 3:使用合并的每个进程的数据管道#
如果您选择此选项,则每个进程
设置一个单一的数据加载器,加载其所有本地设备所需的数据;然后
在传输到每个本地设备之前对本地数据进行分片。
这是执行分布式加载的最有效方法。但是,它也是最复杂的方法,因为需要逻辑来确定每个设备需要哪些数据,并创建一个仅加载所有这些数据(理想情况下,不加载任何其他额外数据)的单个数据加载。
选项 4:以某种方便的方式加载数据,在计算内部重新分片#
此选项更难以解释,但通常比上述选项(从 1 到 3)更容易实现。
想象一个场景,其中很难或不可能设置数据加载器来加载您需要的确切数据,无论是针对每个设备还是每个进程的加载器。但是,仍然可能为每个进程设置一个数据加载器,以加载1 / num_processes
的数据,只是分片不正确。
然后,继续使用您之前提到的 2D 示例分片,假设每个进程更容易加载数据的一列
然后,您可以创建一个具有Sharding
的jax.Array
,该分片表示每列数据,将其直接传递到计算中,并使用jax.lax.with_sharding_constraint()
将列分片输入立即重新分片到所需的分片。并且由于数据在计算内部重新分片,因此它将在加速器通信链路上重新分片(例如,TPU ICI 或 NVLink)。
此选项 4 与选项 3(使用合并的每个进程的数据管道)具有类似的优势
每个进程仍然拥有一个单一的数据加载器;以及
全局数据在所有进程中仅加载一次;以及
全局数据还具有在如何加载数据方面提供更多灵活性的额外优势。
但是,此方法使用加速器互连带宽来执行重新分片,这可能会减慢某些工作负载的速度。选项 4 还要求输入数据表示为一个单独的Sharding
,除了目标Sharding
之外。
复制#
复制描述了一个过程,其中多个设备具有相同的数据分片。上面提到的通用选项(选项 1 到 4)仍然适用于复制。唯一的区别是一些进程最终可能会加载相同的数据分片。本节描述了完整复制和部分复制。
完整复制#
完整复制是一个过程,其中所有设备都具有数据的完整副本(即,数据“分片”是整个数组值)。
在下面的示例中,由于总共有 8 个设备(每个进程 2 个),因此您最终将获得 8 个完整数据的副本。每个数据的副本都是未分片的,即副本位于单个设备上
部分复制#
部分复制描述了一个过程,其中有多个数据的副本,并且每个副本都跨多个设备进行分片。对于给定的数组值,通常有许多可能的方法来执行部分复制(注意:对于给定的数组形状,始终只有一个完全复制的Sharding
)。
以下是两个可能的示例。
在下面的第一个示例中,每个副本跨进程的两个本地设备进行分片,总共 4 个副本。这意味着每个进程都需要加载完整的全局数据,因为其本地设备将具有数据的完整副本。
在下面的第二个示例中,每个副本仍然跨两个设备进行分片,但每个设备对分布在两个不同的进程中。进程0
(粉红色)和进程1
(黄色)都需要加载数据的仅第一行,而进程2
(绿色)和进程3
(蓝色)都需要加载数据的仅第二行
现在您已经了解了创建jax.Array
的高级选项,让我们将它们应用于 ML 应用程序的数据加载。
数据并行#
在纯数据并行(无模型并行)中
您在每个设备上复制模型;以及
每个模型副本(即,每个设备)接收不同的每个副本批次的数据。
当将输入数据表示为单个jax.Array
时,该数组包含此步骤中所有副本的数据(这称为全局批次),其中jax.Array
的每个分片包含单个每个副本的批次。您可以将其表示为跨所有设备的 1D 分片(请查看下面的示例)——换句话说,全局批次由所有每个副本的批次沿批次轴连接在一起组成。
应用此框架,您可以得出结论,进程0
应该获取全局批次的第一四分之一(8 个中的 2 个),而进程1
应该获取第二四分之一,依此类推。
但是,您如何知道第一四分之一是什么?以及如何确保进程0
获取第一四分之一?幸运的是,关于数据并行性有一个非常重要的技巧,这意味着您不必回答这些问题,并且使整个设置更简单。
数据并行性的重要技巧#
技巧在于您不需要关心每个副本的批次落在哪个副本上。因此,哪个进程加载批次并不重要。原因是由于每个设备都对应于执行相同操作的模型副本,因此在全局批次中哪个设备获取哪个每个副本的批次并不重要。
这意味着您可以自由地重新排列全局批次中的每个副本的批次。换句话说,您可以自由地随机化每个设备获取的数据分片。
例如
通常,重新排列jax.Array
的数据分片(如上所示)不是一个好主意——您实际上是在置换jax.Array
的值!但是,对于数据并行性,全局批次顺序没有意义,您可以自由地重新排列全局批次中的每个副本的批次,如之前所述。
这简化了数据加载,因为它意味着每个设备只需要一个独立的每个副本的批次流,这可以通过为每个进程创建一个独立的管道并将生成的每个进程的批次分成每个副本的批次,在大多数数据加载器中轻松实现。
这是选项 2:合并每个进程的数据管道的一个实例。您还可以使用其他选项(例如 0、1 和 3,这些选项在本文档的前面部分介绍过),但此选项相对简单且高效。
以下是如何使用 tf.data 实现此设置的示例
import jax
import tensorflow as tf
import numpy as np
################################################################################
# Step 1: setup the Dataset for pure data parallelism (do once)
################################################################################
# Fake example data (replace with your Dataset)
ds = tf.data.Dataset.from_tensor_slices(
[np.ones((16, 3)) * i for i in range(100)])
ds = ds.shard(num_shards=jax.process_count(), index=jax.process_index())
################################################################################
# Step 2: create a jax.Array of per-replica batches from the per-process batch
# produced from the Dataset (repeat every step). This can be used with batches
# produced by different data loaders as well!
################################################################################
# Grab just the first batch from the Dataset for this example
per_process_batch = ds.as_numpy_iterator().next()
per_process_batch_size = per_process_batch.shape[0] # adjust if your batch dim
# isn't 0
per_replica_batch_size = per_process_batch_size // jax.local_device_count()
assert per_process_batch_size % per_replica_batch_size == 0, \
"This example doesn't implement padding."
per_replica_batches = np.split(per_process_batch, jax.local_device_count())
# Thanks to the very important trick about data parallelism, no need to care what
# order the devices appear in the sharding.
sharding = jax.sharding.PositionalSharding(jax.devices())
# PositionalSharding must have same rank as data being sharded.
sharding = sharding.reshape((jax.device_count(),) +
(1,) * (per_process_batch.ndim - 1))
global_batch_size = per_replica_batch_size * jax.device_count()
global_batch_shape = ((global_batch_size,) + per_process_batch.shape[1:])
global_batch_array = jax.make_array_from_single_device_arrays(
global_batch_shape, sharding,
# Thanks again to the very important trick, no need to care which device gets
# which per-replica batch.
arrays=[jax.device_put(batch, device)
for batch, device
in zip(per_replica_batches, sharding.addressable_devices)])
assert global_batch_array.shape == global_batch_shape
assert (global_batch_array.addressable_shards[0].data.shape ==
per_replica_batches[0].shape)
数据 + 模型并行#
在模型并行中,您将每个模型副本跨多个设备分片。如果您使用纯模型并行(无数据并行)
只有一个模型副本跨所有设备分片;并且
数据(通常)在所有设备上完全复制。
本指南考虑了您同时使用数据和模型并行的情况
您将多个模型副本中的每一个跨多个设备分片;并且
您将数据部分复制到每个模型副本上——同一模型副本中的每个设备都获取相同的每个副本的批次,而跨模型副本的设备获取不同的每个副本的批次。
进程内的模型并行#
出于数据加载的目的,最简单的办法是在单个进程的本地设备中对每个模型副本进行分片。
对于此示例,让我们切换到每个进程 4 个设备的 2 个进程(而不是每个进程 2 个设备的 4 个进程)。考虑每个模型副本跨单个进程的 2 个本地设备分片的情况。这导致每个进程 2 个模型副本,总共 4 个模型副本,如下所示
在这里,输入数据再次表示为单个jax.Array
,具有 1D 分片,其中每个分片都是一个每个副本的批次,但有一个例外
与纯数据并行的情况不同,您引入了部分复制并创建了 1D 分片全局批次的 2 个副本。
这是因为每个模型副本都由 2 个设备组成,每个设备都需要每个副本的批次的副本。
将每个模型副本保留在单个进程中可以简化操作,因为您可以重用上面描述的纯数据并行设置,只是您还需要复制每个副本的批次
注意
将每个副本的批次复制到正确的设备上也至关重要!虽然关于数据并行性的非常重要的技巧意味着您不关心哪个批次最终落在哪个副本上,但您确实关心单个副本只获取单个批次。
例如,这是可以的
但是,如果您不小心将每个批次加载到哪个本地设备上,您可能会意外地创建未复制的数据,即使Sharding
(和并行策略)表明数据已复制
如果您意外地创建了一个jax.Array
,其中包含应该在单个进程中复制但未复制的数据,JAX 将引发错误(但这并非总是适用于跨进程的模型并行;请参阅下一节)。
以下是如何使用tf.data
实现每个进程模型并行和数据并行的示例
import jax
import tensorflow as tf
import numpy as np
################################################################################
# Step 1: Set up the Dataset with a different data shard per-process (do once)
# (same as for pure data parallelism)
################################################################################
# Fake example data (replace with your Dataset)
per_process_batches = [np.ones((16, 3)) * i for i in range(100)]
ds = tf.data.Dataset.from_tensor_slices(per_process_batches)
ds = ds.shard(num_shards=jax.process_count(), index=jax.process_index())
################################################################################
# Step 2: Create a jax.Array of per-replica batches from the per-process batch
# produced from the Dataset (repeat every step)
################################################################################
# Grab just the first batch from the Dataset for this example
per_process_batch = ds.as_numpy_iterator().next()
num_model_replicas_per_process = 2 # set according to your parallelism strategy
num_model_replicas_total = num_model_replicas_per_process * jax.process_count()
per_process_batch_size = per_process_batch.shape[0] # adjust if your batch dim
# isn't 0
per_replica_batch_size = (per_process_batch_size //
num_model_replicas_per_process)
assert per_process_batch_size % per_replica_batch_size == 0, \
"This example doesn't implement padding."
per_replica_batches = np.split(per_process_batch,
num_model_replicas_per_process)
# Create an example `Mesh` for per-process data parallelism. Make sure all devices
# are grouped by process, and then resize so each row is a model replica.
mesh_devices = np.array([jax.local_devices(process_idx)
for process_idx in range(jax.process_count())])
mesh_devices = mesh_devices.reshape(num_model_replicas_total, -1)
# Double check that each replica's devices are on a single process.
for replica_devices in mesh_devices:
num_processes = len(set(d.process_index for d in replica_devices))
assert num_processes == 1
mesh = jax.sharding.Mesh(mesh_devices, ["model_replicas", "data_parallelism"])
# Shard the data across model replicas. You don't shard across the
# data_parallelism mesh axis, meaning each per-replica shard will be replicated
# across that axis.
sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("model_replicas"))
global_batch_size = per_replica_batch_size * num_model_replicas_total
global_batch_shape = ((global_batch_size,) + per_process_batch.shape[1:])
# Create the final jax.Array using jax.make_array_from_callback. The callback
# will be called for each local device, and passed the N-D numpy-style index
# that describes what shard of the global data that device should receive.
#
# You don't need care exactly which index is passed in due to the very important data
# parallelism, but you do use the index argument to make sure you replicate each
# per-replica batch correctly -- the `index` argument will be the same for
# devices in the same model replica, and different for devices in different
# model replicas.
index_to_batch = {}
def callback(index: tuple[slice, ...]) -> np.ndarray:
# Python `slice` objects aren't hashable, so manually create dict key.
index_key = tuple((slice_.start, slice_.stop) for slice_ in index)
if index_key not in index_to_batch:
# You don't care which per-replica batch goes to which replica, just take the
# next unused one.
index_to_batch[index_key] = per_replica_batches[len(index_to_batch)]
return index_to_batch[index_key]
global_batch_array = jax.make_array_from_callback(
global_batch_shape, sharding, callback)
assert global_batch_array.shape == global_batch_shape
assert (global_batch_array.addressable_shards[0].data.shape ==
per_replica_batches[0].shape)
跨进程的模型并行#
当模型副本分布在多个进程中时,情况可能会变得更加有趣,原因可能是
因为单个副本无法容纳在一个进程中;或者
因为设备分配没有这样设置。
例如,回到之前每个进程 2 个设备的 4 个进程的设置,如果您将设备分配给副本,如下所示
这与之前的每个进程模型并行示例具有相同的并行策略——4 个模型副本,每个副本跨 2 个设备分片。唯一的区别是设备分配——每个副本的两个设备分布在不同的进程中,并且每个进程只负责每个副本的批次的一个副本(但对于两个副本)。
像这样将模型副本拆分到多个进程中可能看起来像是一件任意且不必要的事情(在本例中,这可以说确实如此),但实际部署最终可能会使用这种设备分配来最大程度地利用设备之间的通信链路。
数据加载现在变得更加复杂,因为需要在进程之间进行一些额外的协调。在纯数据并行和每个进程模型并行的情况下,每个进程加载唯一的数据流就足够了。现在,某些进程必须加载相同的数据,而某些进程必须加载不同的数据。在上面的示例中,进程0
和2
(分别为粉红色和绿色)必须加载相同的 2 个每个副本的批次,而进程1
和3
(分别为黄色和蓝色)也必须加载相同的 2 个每个副本的批次(但与进程0
和2
的批次不同)。
此外,重要的是每个进程不要混淆其 2 个每个副本的批次。虽然您不关心哪个批次落在哪个副本上(关于数据并行性的非常重要的技巧),但您需要确保副本中的所有设备都获取相同的批次。例如,这将是不好的
注意
截至 2023 年 8 月,JAX 无法检测跨进程的jax.Array
分片是否应该被复制但未被复制,并且在运行计算时会产生错误的结果。因此,请注意不要这样做!
要获取每个设备上正确的每个副本的批次,您需要将全局输入数据表示为以下jax.Array