jax.lax.ppermute#
- jax.lax.ppermute(x, axis_name, perm)[source]#
根据排列
perm
执行集体排列。如果
x
是一个 pytree,那么结果等效于将此函数映射到树中的每个叶子。此函数是 CollectivePermute HLO 的模拟。
- 参数:
x – 具有映射轴 (命名为
axis_name
) 的数组。axis_name – 用于命名 pmapped 轴的可哈希 Python 对象 (有关更多详细信息,请参见
jax.pmap()
文档)。perm – 整数对列表,表示
(source_index, destination_index)
对,用于编码如何对名为axis_name
的映射轴进行洗牌。整数被视为映射轴axis_name
中的索引。任何两个对不应该具有相同的源索引或相同的目标索引。对于轴axis_name
中的每个索引,它不对应于perm
中的目标索引,结果中相应的值将填充为适当类型的零。
- 返回值:
与
x
形状相同的数组,沿着轴axis_name
的切片根据置换perm
从x
中收集。