jax.lax.psum#

jax.lax.psum(x, axis_name, *, axis_index_groups=None)[源代码]#

在 pmapped 轴 axis_name 上对 x 进行全归约求和。

如果 x 是一个 pytree,那么结果等价于将此函数映射到树中的每个叶子节点。

布尔类型 (boolean) 的输入会在归约前转换为整数。

参数:
  • x – 具有名为 axis_name 的映射轴的数组(或数组列表)。

  • axis_name – 用于命名 pmapped 轴的可哈希 Python 对象(有关详细信息,请参阅 jax.pmap() 文档)。

  • axis_index_groups – 可选的包含轴索引列表的列表(例如,对于大小为 4 的轴,[[0, 1], [2, 3]] 将在前两个和后两个副本上执行 psum)。组必须精确地覆盖所有轴索引一次。

返回值:

x 具有相同形状的数组(或数组列表),表示沿轴 axis_name 进行全归约求和的结果。

示例

例如,假设有 4 个 XLA 设备可用

>>> x = np.arange(4)
>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[6 6 6 6]
>>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[0.         0.16666667 0.33333334 0.5       ]

假设我们想在两组之间执行 psum,一组包含 device0device1,另一组包含 device2device3

>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
>>> print(y)
[1 1 5 5]

使用 2D 形状的 x 的示例。每一行来自一个设备的数据。

>>> x = np.arange(16).reshape(4, 4)
>>> print(x)
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]

跨所有设备的完整 psum

>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[[24 28 32 36]
 [24 28 32 36]
 [24 28 32 36]
 [24 28 32 36]]

在两组之间执行 psum

>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
>>> print(y)
[[ 4  6  8 10]
 [ 4  6  8 10]
 [20 22 24 26]
 [20 22 24 26]]