jax.experimental.multihost_utils 模块

jax.experimental.multihost_utils 模块#

用于跨多个主机同步和通信的实用程序。

多主机实用程序 API 参考#

broadcast_one_to_all(in_tree[, is_source])

将数据从源主机(默认主机 0)广播到所有其他主机。

sync_global_devices(name)

在所有主机/设备之间创建一个屏障。

process_allgather(in_tree[, tiled])

从所有进程中收集数据。

assert_equal(in_tree[, fail_message])

验证所有主机是否具有相同的树形值。

host_local_array_to_global_array(...)

将主机本地值转换为全局分片的 jax.Array。

global_array_to_host_local_array(...)

将全局 jax.Array 转换为主机本地 jax.Array