jax.disable_jit#
- jax.disable_jit(disable=True)[源代码]#
上下文管理器,在其动态上下文中禁用
jit()
的行为。对于调试来说,拥有一个在动态上下文中禁用所有地方的
jit()
的机制非常有用。请注意,这不仅会禁用用户显式使用的jit()
,还会删除 JAX 库使用的任何隐式 JIT 编译:这包括传递给高级原语(如scan()
和while_loop()
)的 body 和 cond 函数的隐式 JIT 计算,jax.numpy
函数的实现中使用的 JIT,以及 API 实现中使用jit()
的任何其他情况。但是请注意,即使在 disable_jit 下,单个原始操作仍然会像正常的逐个操作的执行一样被 XLA 编译。具有对 JIT 函数的参数的数据依赖的值会被跟踪和抽象。例如,抽象值可以是
ShapedArray
实例,表示具有给定形状和 dtype 的所有可能数组的集合,但不表示具有特定值的具体数组。如果在 JIT 函数中使用良性副作用操作(如 print),您可能会注意到这些。>>> import jax >>> >>> @jax.jit ... def f(x): ... y = x * 2 ... print("Value of y is", y) ... return y + 3 ... >>> print(f(jax.numpy.array([1, 2, 3]))) Value of y is Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace...> [5 7 9]
这里,
y
已被jit()
抽象为ShapedArray
,它表示具有固定形状和类型但具有任意值的数组。y
的值也被跟踪。 如果我们想在调试时看到一个具体的值,并避免跟踪器,我们可以使用disable_jit()
上下文管理器。>>> import jax >>> >>> with jax.disable_jit(): ... print(f(jax.numpy.array([1, 2, 3]))) ... Value of y is [2 4 6] [5 7 9]
- 参数:
disable (bool)