jax.lax.with_sharding_constraint

jax.lax.with_sharding_constraint#

jax.lax.with_sharding_constraint(x, shardings)[source]#

用于约束 jitted 计算内 Array 分片的机制

对于 GSPMD 分区器,这是一个严格的约束,而不是提示。有关如何使用此函数的示例,请参阅 分布式数组和自动并行化

参数:
  • x – 将对其分片进行约束的 jax.Array 的 Pytree

  • shardings – 分片规范的 Pytree。有效值与 jax.experimental.pjit()in_shardings 参数相同。

返回值:

具有指定分片约束的 jax.Array 的 Pytree。

返回类型:

x_with_shardings