jax.experimental.multihost_utils.broadcast_one_to_all

jax.experimental.multihost_utils.broadcast_one_to_all#

jax.experimental.multihost_utils.broadcast_one_to_all(in_tree, is_source=None)[source]#

将数据从源主机(默认为主机 0)广播到所有其他主机。

参数:
  • in_tree (Any) – 数组的 pytree - 每个数组必须在所有主机上具有相同的形状。

  • is_source (bool | None | None) – 可选的布尔值,表示调用方是否是源。只有“源主机”才会为广播贡献数据。如果为 None,则使用主机 0。

返回值:

与 in_tree 匹配的 pytree,其中叶子现在都包含来自第一个主机的数据。

返回类型:

Any