jax.lax.with_sharding_constraint#

jax.lax.with_sharding_constraint(x, shardings)[源代码]#

约束 JIT 编译计算中数组分片的机制

这是 GSPMD 分区器的严格约束,而不是提示。有关如何使用此函数的示例,请参阅 分布式数组和自动并行化

参数:
  • x – 将约束其分片的 jax.Arrays 的 PyTree

  • shardings – 分片规范的 PyTree。有效值与 jax.experimental.pjit()in_shardings 参数相同。

返回值:

具有指定分片约束的 jax.Arrays 的 PyTree。

返回类型:

x_with_shardings