jax.distributed 模块

jax.distributed 模块#

initialize([coordinator_address, ...])

初始化 JAX 分布式系统。

shutdown()

关闭分布式系统。