jax.lax.optimization_barrier#

jax.lax.optimization_barrier(operand, /)[源代码]#

阻止编译器将操作移动到屏障两侧。

优化屏障有许多可能的用途

  • 优化屏障确保在任何依赖于屏障输出的操作之前评估所有输入。这可以用来强制执行特定的操作顺序。

  • 优化屏障阻止公共子表达式消除。JAX 使用此方法来实现重新物化。

  • 优化屏障阻止编译器融合。也就是说,编译器可能不会将屏障之前的操作与屏障之后的操作融合到同一个内核中。

JAX 没有为优化屏障定义导数或批处理规则。

优化屏障在编译函数之外不起作用。

参数:

operand – JAX 值的 pytree。

返回:

一个 JAX 值的 pytree,其结构和内容与 operand 相同。

示例

阻止两次调用 sin 之间的公共子表达式消除

>>> def f(x):
...   return jax.lax.optimization_barrier(jax.lax.sin(x)) + jax.lax.sin(x)
>>> jax.jit(f)(0.)
Array(0., dtype=float32, weak_type=True)