jax.lax.psum

内容

jax.lax.psum#

jax.lax.psum(x, axis_name, *, axis_index_groups=None)[source]#

在 pmapped 轴 axis_name 上对 x 计算全归约求和。

如果 x 是一个 pytree,则结果等效于将此函数映射到树中的每个叶子。

布尔类型的数据在归约之前会被转换为整数。

参数:
  • x – 具有名为 axis_name 的映射轴的数组。

  • axis_name – 用于命名 pmapped 轴的可散列 Python 对象(有关更多详细信息,请参阅 jax.pmap() 文档)。

  • axis_index_groups – 可选的列表,包含轴索引的列表(例如,对于大小为 4 的轴,[[0, 1], [2, 3]] 将对前两个和后两个副本执行 psum)。组必须完全覆盖所有轴索引。

返回值:

x 形状相同的数组,表示沿轴 axis_name 的全归约求和的结果。

示例

例如,有 4 个可用的 XLA 设备

>>> x = np.arange(4)
>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[6 6 6 6]
>>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[0.         0.16666667 0.33333334 0.5       ]

假设我们希望在两个组之间执行psum操作,一个组包含device0device1,另一个组包含device2device3

>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
>>> print(y)
[1 1 5 5]

使用二维形状的x作为示例。每一行表示一个设备的数据。

>>> x = np.arange(16).reshape(4, 4)
>>> print(x)
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]

跨所有设备执行完整的psum操作

>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[[24 28 32 36]
 [24 28 32 36]
 [24 28 32 36]
 [24 28 32 36]]

在两个组之间执行psum操作

>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
>>> print(y)
[[ 4  6  8 10]
 [ 4  6  8 10]
 [20 22 24 26]
 [20 22 24 26]]