jax.experimental.shard_map.shard_map#

jax.experimental.shard_map.shard_map(f, mesh, in_specs, out_specs, check_rep=True, auto=frozenset({}))[源代码]#

将函数映射到数据的分片上。

注意

shard_map 是一个实验性 API,仍可能发生变化。有关分片数据的介绍,请参阅 并行计算简介。要更深入地了解如何使用 shard_map,请参阅 使用 shard_map 进行 SPMD 多设备并行计算

参数:
  • f (可调用对象) – 要映射的可调用对象。f 的每次应用,或者 f 的“实例”,都会将映射过的参数的分片作为输入,并生成输出的分片。

  • mesh (Mesh | AbstractMesh) – 一个 jax.sharding.Mesh,表示用于分片数据的设备数组以及执行 f 实例的设备。 Mesh 的名称可用于 f 中的集体通信操作。这通常由诸如 jax.experimental.mesh_utils.create_device_mesh() 之类的实用程序函数创建。

  • in_specs (Specs) – 一个以 PartitionSpec 实例作为叶子的 pytree,其树结构是要映射的 args 元组的树前缀。类似于 NamedSharding,每个 PartitionSpec 表示相应的参数(或参数子树)应如何沿着 mesh 的命名轴进行分片。在每个 PartitionSpec 中,在某个位置提及 mesh 轴名称表示沿着该位置轴分片相应的参数数组轴;不提及轴名称表示复制。如果参数或参数子树的对应规范为 None,则该参数不进行分片。

  • out_specs (Specs) – 一个以 PartitionSpec 实例作为叶子的 pytree,其树结构是 f 输出的树前缀。每个 PartitionSpec 表示应如何连接相应的输出分片。在每个 PartitionSpec 中,在某个位置提及 mesh 轴名称表示沿着相应的位置轴连接该 mesh 轴的分片。不提及 mesh 轴名称表示承诺输出值在该 mesh 轴上相等,并且不应连接,而只应生成单个值。

  • check_rep (bool) – 如果为 True(默认),则启用额外的有效性检查和自动微分优化。有效性检查涉及是否在 out_specs 中未提及的任何 mesh 轴名称与 f 的输出的复制方式一致。如果在 f 中使用 Pallas 内核,则必须将其设置为 False。

  • auto (frozenset[AxisName]) – (实验性)mesh 中可选的一组轴名称,我们不会对这些轴分片数据或映射函数,而是允许编译器控制分片。这些名称不能在 in_specsout_specsf 中的通信集合中使用。

返回:

一个可调用对象,它根据 meshin_specs 对分片的数据应用输入函数 f

示例

有关示例,请参阅 并行计算简介使用 shard_map 进行 SPMD 多设备并行计算