jax.make_array_from_process_local_data

jax.make_array_from_process_local_data#

jax.make_array_from_process_local_data(sharding, local_data, global_shape=None)[source]#

使用进程中可用的数据创建分布式张量。

此函数是 make_array_from_callback 的常见特殊情况。它假设数据在进程中可用,并负责索引整理。

最常见的情况是分片跨批次维度分片,并且每个主机只加载其对应的子批次。此函数也支持更一般的情况,例如混合多主机和多轴复制和分片,但您需要正确计算进程本地数据的尺寸和内容以满足分片约束。

特别是,如果两个主机是副本,则 host_local_data 也应该相同。

global_shape 是可选的。如果未提供,它将从 local_data 和 sharding 推断出来,假设每个主机仅代表其自身数据以进行统一分片。如果分片是非均匀的(参见下文说明),则会引发异常。

显式设置 `global_shape` 允许更细粒度的控制,并与非均匀分片一起使用。`global_shape` 的每个维度必须与 `host_local_data` 匹配,或者与分片的推断全局形状匹配(在这种情况下,它等效于将其设置为 `None`,但更显式)。

例如,如果维度 `i` 被完全分片,那么这个大小将是 `per_device_shape[i] * jax.local_device_count()`。每个设备将被映射到 `local_data` 数组的本地切片中。例如,如果给定的进程地址切片为 (8, 12) 和 (24, 28),那么这些切片将被映射到 `local_data` 的 (0, 4) 和 (4, 8)。

对于 `global_shapes` 与 `local_shape` 匹配的每个维度,每个设备将在 `local_data` 中查找切片。例如,如果 `global_shape == local_data.shape`,则假设本地数据是将被分片到设备中的实际目标数组。

如果 `global_shape` 与 `local_data.shape` 相同,则数据在所有主机上必须相同。

示例

>>> from jax.sharding import PartitionSpec as P
>>> mesh_rows = 2
>>> mesh_cols =  jax.device_count() // 2
...
>>> mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(mesh_rows, mesh_cols), ('x', 'y'))
>>> sharding = jax.sharding.NamedSharding(mesh, P(('x', 'y'),))
>>> rows_per_device = 2
>>> feature_length = 32
>>> per_device_shape = (rows_per_device, feature_length)
>>> per_host_shape = (rows_per_device * len(mesh.local_devices), feature_length)
>>> per_host_generator = lambda : np.arange(np.prod(per_host_shape)).reshape(per_host_shape)
>>> per_host_data = per_host_generator()  # replace with your own per-host data pipeline that outputs numpy arrays
>>> global_shape = (rows_per_device * len(sharding.device_set), ) + per_device_shape[1:]
>>> output_global_array = jax.make_array_from_process_local_data(sharding, per_host_data, global_shape)
...
>>> assert output_global_array.addressable_data(0).shape == per_device_shape
>>> assert output_global_array.shape == global_shape

注意:虽然大多数分片是均匀的,但可以设计一个奇特的共享网格,其中每个进程的设备将在某些维度上以非网格状模式排列,或者索引非平凡地重叠。这种分片在这些维度上被称为“非均匀”。在这种情况下,沿着这些方向的全局形状必须与本地形状匹配,因为没有有意义的方法以非重叠的方式表示所有需要的每个进程数据。例如,对于全局形状 4x4,如果分片如下所示

0123 2103 4675 4567

有 4 个进程,分别包含设备 (0,1)、(2, 3)、(4, 5)、(6, 7)。然后每个主机的看起来如下

xx.. ..xx …. …. .xx. x..x …. …. …. …. x..x .xx. …. …. xx.. ..xx

分片在行上是均匀的(每个主机都需要第 1-2 行或第 3-4 行),在列上是非均匀的(主机需要重叠但并不匹配的列集)。因此,本地数据必须具有 2x4 或 4x4 的形状,即使每个主机可以潜在地适合 2x2 形状。在这种情况下,用户必须显式提供 `global_shape`,对于 `local_shape=(2, 4)`,可能的有效全局形状是 (2, 4) 和 (4, 4)。

另一方面,对于分片

0213 x.x. .x.x. …. …. 0213 x.x. .x.x. …. …. 4657 …. …. .x.x x.x. 4657 …. …. .x.x x.x.

对于 `local_shape=(2, 2)`,此函数可以接受 2x2、2x4、4x2 和 4x4 全局形状的选择。在这种情况下,将 `global_shape` 设置为 `None` 等效于将其设置为 (4, 4)。

参数:
  • sharding (Sharding) – 全局张量的分片。

  • local_data (np.ndarray) – 主机上的数据将被放置在本地设备上。每个维度都应与 `global_shape` 匹配,或者与 `num_addressable_indices(dim)` 匹配。

  • global_shape (Shape | None | None) – 全局张量的目标形状。如果为 `None`,则将从 `local_data` 和分片中推断。

返回值:

将具有 `sharding=sharding` 且形状为 `global_shape` 的张量。

返回值类型:

ArrayImpl