jax.experimental.mesh_utils.create_hybrid_device_mesh#
- jax.experimental.mesh_utils.create_hybrid_device_mesh(mesh_shape, dcn_mesh_shape, devices=None, *, process_is_granule=False, should_sort_granules_by_key=True, allow_split_physical_axes=False)[source]#
为混合(例如,ICI 和 DCN)并行性创建设备网格。
- 参数::
mesh_shape (Sequence[int]) – 较快/内部网络的逻辑网格形状,按网络强度递增排序,例如 [replica, data, mdl],其中 mdl 的网络通信需求最高。
dcn_mesh_shape (Sequence[int]) – 较慢/外部网络的逻辑网格形状,与 mesh_shape 的顺序相同。
devices (Sequence[Any] | None | None) – 可选,要为其构建网格的设备。默认为 jax.devices()。
process_is_granule (bool) – 如果为 True,则此函数将把进程视为较慢/外部网络的单元。否则,它将查找设备上的 slice_index 属性并使用切片作为单元。启用此功能是为了作为未设置 slice_index 的平台的回退。
should_sort_granules_by_key (bool) – 是否应该根据颗粒键对设备颗粒进行排序,具体取决于 process_is_granule,可以是切片索引或进程索引。
allow_split_physical_axes (bool) – 如果为 True,我们将根据需要拆分物理轴以生成所需的设备网格。
- 引发:
ValueError – 如果devices所属的切片数量与dcn_mesh_shape的乘积不相等,或者任何单个切片所属的设备数量与mesh_shape的乘积不相等。
- 返回值:
一个形状为 mesh_shape * dcn_mesh_shape 的 JAX 设备 np.ndarray,可以将其馈送到 jax.sharding.Mesh 中以进行混合并行计算。
- 返回类型:
np.ndarray