jax.experimental.multihost_utils.process_allgather#
- jax.experimental.multihost_utils.process_allgather(in_tree, tiled=False)[source]#
从所有进程中收集数据。
- 参数:
in_tree (Any) – 数组的 pytree - 每个数组在所有主机上的形状都_必须_相同。
tiled (bool) – 是否堆叠或连接输出。默认为 False,即在索引 0 处沿新的位置轴堆叠。
- 返回值:
- NumPy 数组的 Pytrees。
如果输入是非完全可寻址的 jax.Array,则数据将完全复制。
如果输入是 NumPy 数组或完全可寻址的 jax.Array,则输出形状取决于 tiled 参数。如果它是 False,则输出将被堆叠,否则将被连接。
如果输入是标量,则输出将被堆叠。
- 返回类型:
任何