jax.disable_jit

内容

jax.disable_jit#

jax.disable_jit(disable=True)[source]#

在动态上下文中禁用 jit() 行为的上下文管理器。

为了调试,拥有一个在动态上下文中禁用 jit() 的机制非常有用。请注意,这不仅禁用了用户对 jit() 的显式使用,还将删除 JAX 库使用的任何隐式 JIT 编译:这包括对传递给高级原语(如 scan()while_loop())的 bodycond 函数的隐式 JIT 计算、jax.numpy 函数实现中使用的 JIT,以及任何其他在 API 实现中使用 jit() 的情况。但是请注意,即使在 disable_jit 下,单个原语操作仍将像在正常的急切逐操作执行中一样由 XLA 编译。

对 jit 函数的参数具有数据依赖关系的值会被追踪并抽象化。例如,一个抽象值可能是一个 ShapedArray 实例,表示具有给定形状和数据类型的所有可能数组的集合,但不表示具有特定值的具体数组。如果您在 jit 函数中使用了良性的副作用操作(例如打印),您可能会注意到这些情况。

>>> import jax
>>>
>>> @jax.jit
... def f(x):
...   y = x * 2
...   print("Value of y is", y)
...   return y + 3
...
>>> print(f(jax.numpy.array([1, 2, 3])))  
Value of y is Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace...>
[5 7 9]

这里 y 已被 jit() 抽象化为 ShapedArray,它表示一个具有固定形状和类型但值任意的数组。y 的值也会被追踪。如果我们想在调试时看到具体的值,并避免追踪器,我们可以使用 disable_jit() 上下文管理器

>>> import jax
>>>
>>> with jax.disable_jit():
...   print(f(jax.numpy.array([1, 2, 3])))
...
Value of y is [2 4 6]
[5 7 9]
参数:

disable (布尔值)