jax.lax.with_sharding_constraint#
- jax.lax.with_sharding_constraint(x, shardings)[源代码]#
约束 JIT 编译计算中数组分片的机制
这是 GSPMD 分区器的严格约束,而不是提示。有关如何使用此函数的示例,请参阅 分布式数组和自动并行化。
- 参数:
x – 将约束其分片的 jax.Arrays 的 PyTree
shardings – 分片规范的 PyTree。有效值与
jax.experimental.pjit()
的in_shardings
参数相同。
- 返回值:
具有指定分片约束的 jax.Arrays 的 PyTree。
- 返回类型:
x_with_shardings