jax.experimental.mesh_utils 模块

内容

jax.experimental.mesh_utils 模块#

用于构建设备网格的实用程序。

API#

create_device_mesh(mesh_shape[, devices, ...])

为 jax.sharding.Mesh 创建一个高性能的设备网格。

create_hybrid_device_mesh(mesh_shape, ...[, ...])

为混合(例如,ICI 和 DCN)并行化创建设备网格。