jax.lax.ppermute

内容

jax.lax.ppermute#

jax.lax.ppermute(x, axis_name, perm)[source]#

根据排列 perm 执行集体排列。

如果 x 是一个 pytree,那么结果等效于将此函数映射到树中的每个叶子。

此函数是 CollectivePermute HLO 的模拟。

参数:
  • x – 具有映射轴 (命名为 axis_name) 的数组。

  • axis_name – 用于命名 pmapped 轴的可哈希 Python 对象 (有关更多详细信息,请参见 jax.pmap() 文档)。

  • perm – 整数对列表,表示 (source_index, destination_index) 对,用于编码如何对名为 axis_name 的映射轴进行洗牌。整数被视为映射轴 axis_name 中的索引。任何两个对不应该具有相同的源索引或相同的目标索引。对于轴 axis_name 中的每个索引,它不对应于 perm 中的目标索引,结果中相应的值将填充为适当类型的零。

返回值:

x 形状相同的数组,沿着轴 axis_name 的切片根据置换 permx 中收集。