jax.lax.psum_scatter

jax.lax.psum_scatter#

jax.lax.psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, tiled=False)[source]#

类似于 psum(x, axis_name),但每个设备仅保留结果的一部分。

例如,psum_scatter(x, axis_name, scatter_dimension=0, tiled=False) 计算的结果与 psum(x, axis_name)[axis_index(axis_name)] 相同,但效率更高。因此,psum 结果沿着映射轴被分散。

计算 psum(x, axis_name) 的一种有效算法是执行 psum_scatter,然后执行 all_gather,实质上是计算 all_gather(psum_scatter(x, axis_name))。因此,我们可以将 psum_scatter 视为 psum 的“前半部分”。

参数:
  • x – 具有名为 axis_name 的映射轴的数组。

  • axis_name – 用于命名映射轴的可哈希 Python 对象(有关更多详细信息,请参阅 jax.pmap() 文档)。

  • scatter_dimension – 一个位置轴,沿着 axis_name 的所有归约结果将散布到该轴上。

  • axis_index_groups – 可选的整数列表列表,包含轴索引。例如,对于大小为 4 的轴,axis_index_groups=[[0, 1], [2, 3]] 将在第一个两个和最后两个轴索引上运行归约散布。组必须完全覆盖所有轴索引,并且所有组的大小必须相同。

  • tiled – 表示是否使用秩保持“平铺”行为的布尔值。当 False(默认值)时,scatter_dimension 中维度的尺寸必须与轴 axis_name 的尺寸匹配(如果给定 axis_index_groups,则与组大小匹配)。沿着 scatter_dimension 散布所有归约结果后,输出通过移除 scatter_dimension 来压缩,因此结果的秩低于输入。当 True 时,scatter_dimension 中维度的尺寸必须能够被轴 axis_name 的尺寸整除(如果给定 axis_index_groups,则与组大小匹配),并且 scatter_dimension 轴将被保留(因此结果的秩与输入相同)。

返回值:

x 形状类似的数组,除了位置 scatter_dimension 中维度的尺寸将被轴 axis_name 的尺寸除(当 tiled=True 时),或者位置 scatter_dimension 中的维度将被消除(当 tiled=False 时)。

例如,如果有 4 个 XLA 设备可用

>>> x = np.arange(16).reshape(4, 4)
>>> print(x)
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
>>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i'), axis_name='i')(x)
>>> print(y)
[24 28 32 36]

如果使用平铺

>>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i', tiled=True), axis_name='i')(x)
>>> print(y)
[[24]
 [28]
 [32]
 [36]]

使用 axis_index_groups 的示例

>>> def f(x):
...   return jax.lax.psum_scatter(
...       x, 'i', axis_index_groups=[[0, 2], [3, 1]], tiled=True)
>>> y = jax.pmap(f, axis_name='i')(x)
>>> print(y)
[[ 8 10]
 [20 22]
 [12 14]
 [16 18]]