jax.ensure_compile_time_eval#

jax.ensure_compile_time_eval()[源代码]#

上下文管理器,确保在跟踪/编译时(或错误)进行评估。

一些 JAX API,如 jax.jit()jax.lax.scan() 涉及暂存,即延迟数值表达式(如 jax.numpy 函数应用)的评估,以便在评估相应的 Python 表达式时,不是立即执行这些计算,而是单独执行,例如在优化编译之后。但是,这种延迟可能是不希望的。例如,可能需要数值来评估 Python 控制流,因此不能延迟它们的评估。另一个例子是,出于性能原因,确保编译时评估(或“常量折叠”)可能是有益的。

此上下文管理器确保 JAX 计算被立即评估。如果无法立即评估,则会引发 ConcretizationTypeError

这是一个人为的例子

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  with jax.ensure_compile_time_eval():
    y = jnp.sin(3.0)
    z = jnp.sin(y)
    z_positive = z > 0
  if z_positive:  # z_positive is usable in Python control flow
    return jnp.sin(x)
  else:
    return jnp.cos(x)

这是一个来自 jax-ml/jax#3974 的真实示例

import jax
import jax.numpy as jnp
from jax import random

@jax.jit
def jax_fn(x):
  with jax.ensure_compile_time_eval():
    y = random.randint(random.key(0), (1000,1000), 0, 100)
  y2 = y @ y
  x2 = jnp.sum(y2) * x
  return x2

通常可以通过简单地将常量表达式“提升”出相应的暂存 API 来实现类似的行为

y = random.randint(random.key(0), (1000,1000), 0, 100)

@jax.jit
def jax_fn(x):
  y2 = y @ y
  x2 = jnp.sum(y2)*x
  return x2

但在某些情况下,使用此上下文管理器可能更方便。