jax.experimental.mesh_utils 模块#

用于构建设备网格的实用工具。

API#

create_device_mesh(mesh_shape[, devices, ...])

为 jax.sharding.Mesh 创建一个高性能的设备网格。

create_hybrid_device_mesh(mesh_shape, ...[, ...])

为混合并行(例如,ICI 和 DCN)创建一个设备网格。