jax.lax.psum_scatter#
- jax.lax.psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, tiled=False)[source]#
类似于
psum(x, axis_name)
,但每个设备仅保留结果的一部分。例如,
psum_scatter(x, axis_name, scatter_dimension=0, tiled=False)
计算的结果与psum(x, axis_name)[axis_index(axis_name)]
相同,但效率更高。因此,psum
结果沿着映射轴被分散。计算
psum(x, axis_name)
的一种有效算法是执行psum_scatter
,然后执行all_gather
,实质上是计算all_gather(psum_scatter(x, axis_name))
。因此,我们可以将psum_scatter
视为psum
的“前半部分”。- 参数:
x – 具有名为
axis_name
的映射轴的数组。axis_name – 用于命名映射轴的可哈希 Python 对象(有关更多详细信息,请参阅
jax.pmap()
文档)。scatter_dimension – 一个位置轴,沿着
axis_name
的所有归约结果将散布到该轴上。axis_index_groups – 可选的整数列表列表,包含轴索引。例如,对于大小为 4 的轴,
axis_index_groups=[[0, 1], [2, 3]]
将在第一个两个和最后两个轴索引上运行归约散布。组必须完全覆盖所有轴索引,并且所有组的大小必须相同。tiled – 表示是否使用秩保持“平铺”行为的布尔值。当
False
(默认值)时,scatter_dimension
中维度的尺寸必须与轴axis_name
的尺寸匹配(如果给定axis_index_groups
,则与组大小匹配)。沿着scatter_dimension
散布所有归约结果后,输出通过移除scatter_dimension
来压缩,因此结果的秩低于输入。当True
时,scatter_dimension
中维度的尺寸必须能够被轴axis_name
的尺寸整除(如果给定axis_index_groups
,则与组大小匹配),并且scatter_dimension
轴将被保留(因此结果的秩与输入相同)。
- 返回值:
与
x
形状类似的数组,除了位置scatter_dimension
中维度的尺寸将被轴axis_name
的尺寸除(当tiled=True
时),或者位置scatter_dimension
中的维度将被消除(当tiled=False
时)。
例如,如果有 4 个 XLA 设备可用
>>> x = np.arange(16).reshape(4, 4) >>> print(x) [[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11] [12 13 14 15]] >>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i'), axis_name='i')(x) >>> print(y) [24 28 32 36]
如果使用平铺
>>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i', tiled=True), axis_name='i')(x) >>> print(y) [[24] [28] [32] [36]]
使用 axis_index_groups 的示例
>>> def f(x): ... return jax.lax.psum_scatter( ... x, 'i', axis_index_groups=[[0, 2], [3, 1]], tiled=True) >>> y = jax.pmap(f, axis_name='i')(x) >>> print(y) [[ 8 10] [20 22] [12 14] [16 18]]