jax.lax.psum_scatter#

jax.lax.psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, tiled=False)[源代码]#

类似于 psum(x, axis_name),但每个设备只保留部分结果。

例如,psum_scatter(x, axis_name, scatter_dimension=0, tiled=False) 计算的值与 psum(x, axis_name)[axis_index(axis_name)] 相同,但效率更高。因此,psum 的结果会分散在映射的轴上。

一种计算 psum(x, axis_name) 的高效算法是执行 psum_scatter,然后执行 all_gather,本质上是计算 all_gather(psum_scatter(x, axis_name))。因此,我们可以将 psum_scatter 视为 psum 的“前半部分”。

参数:
  • x – 具有名为 axis_name 的映射轴的数组。

  • axis_name – 用于命名映射轴的可哈希 Python 对象(有关详细信息,请参阅 jax.pmap() 文档)。

  • scatter_dimension – 一个位置轴,沿 axis_name 的 all-reduce 结果将分散到该轴中。

  • axis_index_groups – 可选的整数列表的列表,包含轴索引。例如,对于大小为 4 的轴,axis_index_groups=[[0, 1], [2, 3]] 将在头两个和最后两个轴索引上运行 reduce-scatter。组必须恰好覆盖所有轴索引一次,并且所有组的大小必须相同。

  • tiled – 一个布尔值,表示是否使用保留秩的“平铺”行为。当 False(默认值)时,scatter_dimension 中的维度大小必须与轴 axis_name 的大小(或如果给定 axis_index_groups 则为组大小)匹配。沿 scatter_dimension 分散 all-reduce 结果后,通过删除 scatter_dimension 来压缩输出,因此结果的秩低于输入。当 True 时,scatter_dimension 中的维度大小必须可被轴 axis_name 的大小(或如果给定 axis_index_groups 则为组大小)整除,并且保留 scatter_dimension 轴(因此结果与输入具有相同的秩)。

返回值:

形状与 x 相似的数组,只是位置 scatter_dimension 中的维度大小除以轴 axis_name 的大小(当 tiled=True 时),或者位置 scatter_dimension 中的维度被消除(当 tiled=False 时)。

例如,如果可以使用 4 个 XLA 设备

>>> x = np.arange(16).reshape(4, 4)
>>> print(x)
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
>>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i'), axis_name='i')(x)
>>> print(y)
[24 28 32 36]

如果使用平铺

>>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i', tiled=True), axis_name='i')(x)
>>> print(y)
[[24]
 [28]
 [32]
 [36]]

使用 axis_index_groups 的示例

>>> def f(x):
...   return jax.lax.psum_scatter(
...       x, 'i', axis_index_groups=[[0, 2], [3, 1]], tiled=True)
>>> y = jax.pmap(f, axis_name='i')(x)
>>> print(y)
[[ 8 10]
 [20 22]
 [12 14]
 [16 18]]