jax.lax.ppermute#

jax.lax.ppermute(x, axis_name, perm)[源代码]#

根据置换 perm 执行集体置换。

如果 x 是一个 pytree,那么结果等价于将此函数映射到树中的每个叶子节点。

此函数是 CollectivePermute HLO 的一个模拟。

参数:
  • x – 具有名为 axis_name 的映射轴的数组。

  • axis_name – 可哈希的 Python 对象,用于命名一个 pmapped 轴(有关详细信息,请参阅 jax.pmap() 文档)。

  • perm – 由整数对组成的列表,表示 (source_index, destination_index) 对,用于编码名为 axis_name 的映射轴应如何洗牌。整数值被视为映射轴 axis_name 的索引。任何两个配对不应具有相同的源索引或相同的目标索引。对于轴 axis_name 的每个不对应于 perm 中目标索引的索引,结果中的相应值将填充适当类型的零。

返回:

x 形状相同的数组,其沿轴 axis_name 的切片根据置换 permx 中收集。