jax.experimental.mesh_utils.create_device_mesh#
- jax.experimental.mesh_utils.create_device_mesh(mesh_shape, devices=None, *, contiguous_submeshes=False, allow_split_physical_axes=False)[源代码]#
为 jax.sharding.Mesh 创建一个高性能的设备网格。
- 参数:
mesh_shape (Sequence[int]) – 逻辑网格的形状,按照网络强度递增的顺序排列,例如 [副本, 数据, 模型],其中模型具有最多的网络通信需求。
devices (Sequence[Any] | None | None) – 可选参数,用于构造网格的设备。默认为 jax.devices()。
contiguous_submeshes (bool) – 如果为 True,此函数将尝试创建一个网格,其中每个进程的本地设备形成一个连续的子网格。如果此函数无法生成合适的网格,则会引发 ValueError。在引入 jax.Array 之前,有时需要此设置以确保非参差的本地数组;如果使用 jax.Arrays,最好将其设置为 False。
allow_split_physical_axes (bool) – 如果为 True,必要时我们会拆分物理轴以生成所需的设备网格。
- 引发:
ValueError – 如果设备数量不等于 mesh_shape 的乘积。
- 返回:
一个 np.ndarray 形式的 JAX 设备,其形状为 mesh_shape,可以馈送到 jax.sharding.Mesh 中,具有良好的集体性能。
- 返回类型:
np.ndarray