jax.lax.all_gather#
- jax.lax.all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False)[source]#
在所有副本中收集 x 的值。
如果
x
是一个 pytree,那么结果等效于将此函数映射到树中的每个叶节点。这等效于,但比 all_to_all(broadcast(x)) 快。
- 参数::
x – 具有名为
axis_name
的映射轴的数组。axis_name – 可散列的 Python 对象,用于命名 pmapped 轴(有关更多详细信息,请参见
jax.pmap()
文档)。axis_index_groups – 可选的列表,包含轴索引列表(例如,对于大小为 4 的轴,[[0, 1], [2, 3]] 将在第一个和最后两个副本上运行 all gather)。这些组必须完全覆盖所有轴索引一次,并且所有组必须具有相同的大小。
axis – 位置轴,沿着
axis_name
的块将被连接到该轴。tiled – 当
False
时,这些块将被堆叠到输出中索引axis
的一个新的位置轴上。当True
时,axis
必须引用一个现有的位置维度,并且这些块将被连接到该维度。
- 返回值::
表示沿轴
axis_name
进行全收集操作的结果的数组。 形状与x.shape
相同,但当
tiled
为False
时,在位置axis
处有一个新维度,其大小等于轴axis_name
的大小,当
tiled
为True
时,位置axis
处维度的尺寸将乘以轴axis_name
的尺寸。
例如,当有 4 个可用的 XLA 设备时
>>> x = np.arange(4) >>> y = jax.pmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(x) >>> print(y) [[0 1 2 3] [0 1 2 3] [0 1 2 3] [0 1 2 3]]
使用 axis_index_groups 的示例,组按偶数和奇数设备 ID 分割
>>> x = np.arange(16).reshape(4, 4) >>> print(x) [[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11] [12 13 14 15]] >>> def f(x): ... return jax.lax.all_gather( ... x, 'i', axis_index_groups=[[0, 2], [3, 1]]) >>> y = jax.pmap(f, axis_name='i')(x) >>> print(y) [[[ 0 1 2 3] [ 8 9 10 11]] [[12 13 14 15] [ 4 5 6 7]] [[ 0 1 2 3] [ 8 9 10 11]] [[12 13 14 15] [ 4 5 6 7]]]