jax.experimental.shard_map.shard_map

内容

jax.experimental.shard_map.shard_map#

jax.experimental.shard_map.shard_map(f, mesh, in_specs, out_specs, check_rep=True, auto=frozenset({}))[source]#

将函数映射到数据的分片上。

注意

shard_map 是一个实验性 API,目前仍在不断变化。有关分片数据的介绍,请参阅 并行编程简介。有关使用 shard_map 的更深入说明,请参阅 使用 shard_map 进行 SPMD 多设备并行化.

参数:
  • f (Callable) – 要映射的可调用对象。每次应用 f,或 f 的“实例”,都会将映射的输入参数的一个分片作为输入,并生成输出的一个分片。

  • mesh (Mesh | AbstractMesh) – 表示数据的切片和函数 f 实例执行的设备数组的 jax.sharding.MeshMesh 的名称可以在 f 中的集体通信操作中使用。这通常由一个实用函数创建,例如 jax.experimental.mesh_utils.create_device_mesh()

  • in_specs (Specs) – 具有 PartitionSpec 实例作为叶子的 pytree,其树结构是将要映射的 args 元组的树前缀。类似于 NamedSharding,每个 PartitionSpec 表示对应参数(或参数子树)应该如何在 mesh 的命名轴上切片。在每个 PartitionSpec 中,在某个位置提及 mesh 轴名称表示沿着该位置轴切片对应参数数组轴;不提及轴名称表示复制。如果参数或参数子树具有相应的 None 规范,则该参数不会被切片。

  • out_specs (Specs) – 具有 PartitionSpec 实例作为叶子的 pytree,其树结构是 f 输出的树前缀。每个 PartitionSpec 表示对应输出切片应该如何连接。在每个 PartitionSpec 中,在某个位置提及 mesh 轴名称表示沿着对应位置轴连接该网格轴切片的切片。不提及 mesh 轴名称表示一个承诺,即输出值沿着该网格轴相等,并且应该只生成一个值而不是连接。

  • check_rep (bool) – 如果为 True(默认),则启用额外的有效性检查和自动微分优化。有效性检查涉及任何未在 out_specs 中提及的网格轴名称是否与 f 输出的复制方式一致。如果在 f 中使用 Pallas 内核,则必须设置为 False。

  • auto (frozenset[AxisName]) – (实验性)来自 mesh 的可选轴名称集,我们不会在这些名称上对数据进行切片或映射函数,而是允许编译器控制切片。这些名称不能在 in_specsout_specsf 中的通信集体中使用。

返回:

一个可调用对象,它根据 meshin_specs 对切片的数据应用输入函数 f

示例

有关示例,请参阅 并行编程简介使用 shard_map 的 SPMD 多设备并行性