jax.lax.all_gather#
- jax.lax.all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False)[源代码]#
跨所有副本收集 x 的值。
如果
x
是一个 pytree,则结果等同于将此函数映射到树中的每个叶子节点。这等效于 all_to_all(broadcast(x)),但速度更快。
- 参数:
x – 具有名为
axis_name
的映射轴的数组。axis_name – 用于命名 pmapped 轴的可哈希 Python 对象(有关更多详细信息,请参阅
jax.pmap()
文档)。axis_index_groups – 可选的列表列表,包含轴索引(例如,对于大小为 4 的轴,[[0, 1], [2, 3]] 将在前面两个和最后两个副本上运行 all gather)。组必须恰好覆盖所有轴索引一次,并且所有组的大小必须相同。
axis – 一个位置轴,沿着
axis_name
的块将连接到该轴中。tiled – 当
False
时,分块将在输出中axis
索引处堆叠到一个新的位置轴上。当True
时,axis
必须引用一个现有的位置维度,并且分块将连接到该维度中。
- 返回:
表示沿轴
axis_name
进行全收集操作的结果的数组。形状与x.shape
相同,但是当
tiled
为False
时,在位置axis
处会有一个新的维度,其大小等于轴axis_name
的大小,当
tiled
为True
时,位置axis
中的维度大小将乘以轴axis_name
的大小。
例如,如果有 4 个可用的 XLA 设备
>>> x = np.arange(4) >>> y = jax.pmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(x) >>> print(y) [[0 1 2 3] [0 1 2 3] [0 1 2 3] [0 1 2 3]]
使用 axis_index_groups 的示例,按偶数和奇数设备 ID 分组
>>> x = np.arange(16).reshape(4, 4) >>> print(x) [[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11] [12 13 14 15]] >>> def f(x): ... return jax.lax.all_gather( ... x, 'i', axis_index_groups=[[0, 2], [3, 1]]) >>> y = jax.pmap(f, axis_name='i')(x) >>> print(y) [[[ 0 1 2 3] [ 8 9 10 11]] [[12 13 14 15] [ 4 5 6 7]] [[ 0 1 2 3] [ 8 9 10 11]] [[12 13 14 15] [ 4 5 6 7]]]