预先降低和编译#

JAX 提供了几种转换,例如 jax.jitjax.pmap,返回一个在加速器或 CPU 上编译并运行的函数。正如 JIT 首字母缩写词所表示的那样,所有编译都在执行的即时进行。

某些情况需要预先 (AOT) 编译。 当你希望在执行时之前完全编译,或者你希望控制编译过程的不同部分何时发生时,JAX 会为你提供一些选项。

首先,让我们回顾一下编译的各个阶段。假设 fjax.jit() 输出的函数/可调用对象,例如对于某个输入可调用对象 F,有 f = jax.jit(F)。当使用参数调用它时,例如 f(x, y),其中 xy 是数组,JAX 按以下顺序执行操作:

  1. 分离阶段:将原始 Python 可调用对象 F 的特定版本分离到内部表示形式。这种特定化反映了 F 对从参数 xy 的属性(通常是它们的形状和元素类型)推断出的输入类型的限制。

  2. 降级:将这种经过特定化、分离出的计算降级为 XLA 编译器的输入语言 StableHLO。

  3. 编译:编译降级后的 HLO 程序,生成针对目标设备(CPU、GPU 或 TPU)优化的可执行文件。

  4. 执行:使用数组 xy 作为参数执行编译后的可执行文件。

JAX 的 AOT API 可以让您直接控制步骤 #2、#3 和 #4(但 不是 #1),以及沿途的其他一些功能。一个示例:

>>> import jax

>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4

>>> lowered = jax.jit(f).lower(x, y)

>>> # Print lowered HLO
>>> print(lowered.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32> {jax.result_info = ""}) {
    %c = stablehlo.constant dense<2> : tensor<i32>
    %0 = stablehlo.multiply %c, %arg0 : tensor<i32>
    %1 = stablehlo.add %0, %arg1 : tensor<i32>
    return %1 : tensor<i32>
  }
}

>>> compiled = lowered.compile()

>>> # Query for cost analysis, print FLOP estimate
>>> compiled.cost_analysis()[0]['flops']
2.0

>>> # Execute the compiled function!
>>> compiled(x, y)
Array(10, dtype=int32, weak_type=True)

请注意,降级后的对象只能在降级它们的同一进程中使用。有关导出用例,请参阅 导出和序列化 API。

有关降级和编译函数提供的更多功能,请参阅 jax.stages 文档。

在上面的 jax.jit 的位置,您也可以 lower(...) jax.pmap() 的结果,以及 pjitxmap (分别来自 jax.experimental.pjitjax.experimental.maps)。在每种情况下,您都可以类似地 compile() 结果。

jit 的所有可选参数(例如 static_argnums)在相应的降级、编译和执行中都会被遵守。对于 pmappjitxmap 也是如此。

在上面的示例中,我们可以用具有 shapedtype 属性的任何对象替换 lower 的参数。

>>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y)
Array(10, dtype=int32)

更一般地说,lower 只需要其参数在结构上提供 JAX 必须知道的用于特定化和降级的信息。对于像上面这样的典型数组参数,这意味着 shapedtype 字段。相比之下,对于静态参数,JAX 需要实际的数组值(更多内容请见下面)。

使用与其降级不兼容的参数调用 AOT 编译的函数会引发错误。

>>> x_1d = y_1d = jnp.arange(3)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d)  
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with int32[3]
Argument 'y' compiled with int32[] and called with int32[3]

>>> x_f = y_f = jnp.float32(72.)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f)  
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with float32[]
Argument 'y' compiled with int32[] and called with float32[]

类似地,AOT 编译的函数不能被 JAX 的即时转换转换,例如 jax.jitjax.grad()jax.vmap()

使用静态参数进行降级#

使用静态参数进行降级强调了传递给 jax.jit 的选项、传递给 lower 的参数以及调用生成的编译函数所需的参数之间的交互。继续上面的示例:

>>> lowered_with_x = jax.jit(f, static_argnums=0).lower(7, 8)

>>> # Lowered HLO, specialized to the *value* of the first argument (7)
>>> print(lowered_with_x.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>) -> (tensor<i32> {jax.result_info = ""}) {
    %c = stablehlo.constant dense<14> : tensor<i32>
    %0 = stablehlo.add %c, %arg0 : tensor<i32>
    return %0 : tensor<i32>
  }
}

>>> lowered_with_x.compile()(5)
Array(19, dtype=int32, weak_type=True)

lower 的结果直接序列化以在不同的进程中使用是不安全的。有关此目的的更多 API,请参阅 导出和序列化

请注意,这里的 lower 像往常一样接受两个参数,但随后的编译函数仅接受剩余的非静态第二个参数。静态的第一个参数(值 7)在降级时被视为常量,并构建到降级的计算中,它可能与其他常量折叠在一起。在这种情况下,它乘以 2 的运算被简化,从而得到常量 14。

虽然上面的 lower 的第二个参数可以用空心的形状/dtype 结构替换,但静态的第一个参数必须是一个具体的值。否则,降级会出错:

>>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar)  
Traceback (most recent call last):
TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'

>>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5)
Array(25, dtype=int32)

AOT 编译的函数不能被转换#

编译的函数专门针对一组特定的参数“类型”,例如在我们的运行示例中具有特定形状和元素类型的数组。从 JAX 的内部角度来看,诸如 jax.vmap() 之类的转换会以使编译后的类型签名无效的方式更改函数的类型签名。作为一项策略,JAX 只是不允许编译的函数参与转换。示例:

>>> def g(x):
...   assert x.shape == (3, 2)
...   return x @ jnp.ones(2)

>>> def make_z(*shape):
...   return jnp.arange(np.prod(shape)).reshape(shape)

>>> z, zs = make_z(3, 2), make_z(4, 3, 2)

>>> g_jit = jax.jit(g)
>>> g_aot = jax.jit(g).lower(z).compile()

>>> jax.vmap(g_jit)(zs)
Array([[ 1.,  5.,  9.],
       [13., 17., 21.],
       [25., 29., 33.],
       [37., 41., 45.]], dtype=float32)

>>> jax.vmap(g_aot)(zs)  
Traceback (most recent call last):
TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax._src.interpreters.batching.BatchTracer'>

g_aot 参与自动微分时(例如 jax.grad()),也会引发类似的错误。为了保持一致性,也不允许通过 jax.jit 进行转换,即使 jit 没有实际修改其参数的类型签名。

调试信息和分析(如果可用)#

除了主要的 AOT 功能(单独和显式的降级、编译和执行)之外,JAX 的各种 AOT 阶段还提供了一些其他功能,以帮助进行调试和收集编译器反馈。

例如,如上面的初始示例所示,降级函数通常提供文本表示形式。编译后的函数也是如此,并且还提供来自编译器的成本和内存分析。所有这些都是通过 jax.stages.Loweredjax.stages.Compiled 对象上的方法提供的(例如,上面的 lowered.as_text()compiled.cost_analysis())。

这些方法旨在作为手动检查和调试的辅助工具,而不是作为可靠的可编程 API。它们的可用性和输出因编译器、平台和运行时而异。这带来了两个重要的注意事项:

  1. 如果 JAX 的当前后端上某些功能不可用,则该方法会返回一些无关紧要的内容(并且类似于 False)。例如,如果 JAX 底层的编译器不提供成本分析,则 compiled.cost_analysis() 将为 None

  2. 如果某些功能可用,则对相应方法提供的内容仍然只有非常有限的保证。返回值不需要在 JAX 配置、后端/平台、版本甚至方法的调用之间在类型、结构或值方面保持一致。JAX 无法保证某一天 compiled.cost_analysis() 的输出在第二天保持不变。

如有疑问,请参阅 jax.stages 的软件包 API 文档。

检查分离出的计算#

本说明顶部列表中的阶段 #1 提到了在降级之前的特定化和分离。JAX 内部对专门针对其参数类型的函数的概念并不总是内存中的具体数据结构。要显式构造 JAX 在内部 Jaxpr 中间语言中对函数的特定化的视图,请参阅 jax.make_jaxpr()