jax.lax.with_sharding_constraint#
- jax.lax.with_sharding_constraint(x, shardings)[source]#
用于约束 jitted 计算内 Array 分片的机制
对于 GSPMD 分区器,这是一个严格的约束,而不是提示。有关如何使用此函数的示例,请参阅 分布式数组和自动并行化。
- 参数:
x – 将对其分片进行约束的 jax.Array 的 Pytree
shardings – 分片规范的 Pytree。有效值与
jax.experimental.pjit()
的in_shardings
参数相同。
- 返回值:
具有指定分片约束的 jax.Array 的 Pytree。
- 返回类型:
x_with_shardings