提前降低和编译#
JAX 提供了多种转换,例如 jax.jit
和 jax.pmap
,返回一个在加速器或 CPU 上编译并运行的函数。正如 JIT 首字母缩略词所暗示的那样,所有编译都发生在执行时的 即时。
有些情况需要 提前 (AOT) 编译。当您想要在执行时间之前完全编译,或者您想要控制编译过程的不同部分何时发生时,JAX 为您提供了一些选项。
首先,让我们回顾一下编译阶段。假设 f
是由 jax.jit()
输出的函数/可调用对象,例如 f = jax.jit(F)
,其中 F
是某个输入可调用对象。当它被调用时,例如 f(x, y)
,其中 x
和 y
是数组,JAX 会按照以下顺序执行:
分阶段原始 Python 可调用对象
F
的专门版本到内部表示。专门化反映了F
的限制,该限制仅限于从参数x
和y
的属性(通常是它们的形状和元素类型)推断出的输入类型。降低此专门的分阶段计算到 XLA 编译器的输入语言 StableHLO。
编译降低的 HLO 程序以生成针对目标设备(CPU、GPU 或 TPU)的优化可执行文件。
执行已编译的可执行文件,数组
x
和y
作为参数。
JAX 的 AOT API 使您可以直接控制步骤 #2、#3 和 #4(但 不包括 #1),以及沿途的一些其他功能。一个例子
>>> import jax
>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4
>>> lowered = jax.jit(f).lower(x, y)
>>> # Print lowered HLO
>>> print(lowered.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}, %arg1: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%c = stablehlo.constant dense<2> : tensor<i32>
%0 = stablehlo.multiply %c, %arg0 : tensor<i32>
%1 = stablehlo.add %0, %arg1 : tensor<i32>
return %1 : tensor<i32>
}
}
>>> compiled = lowered.compile()
>>> # Query for cost analysis, print FLOP estimate
>>> compiled.cost_analysis()[0]['flops']
2.0
>>> # Execute the compiled function!
>>> compiled(x, y)
Array(10, dtype=int32, weak_type=True)
请注意,降级的对象只能在它们被降级的同一个进程中使用。有关导出用例,请参阅导出和序列化 API。
有关降级和编译函数提供的功能的更多详细信息,请参阅jax.stages
文档。
除了上面的jax.jit
之外,您还可以lower(...)
jax.pmap()
的结果,以及pjit
和xmap
(分别来自jax.experimental.pjit
和jax.experimental.maps
)。在每种情况下,您都可以类似地compile()
结果。
所有可选参数jit
(例如static_argnums
)在相应的降级、编译和执行中得到遵守。同样适用于pmap
,pjit
和xmap
。
在上面的示例中,我们可以用任何具有shape
和dtype
属性的对象替换lower
的参数
>>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y)
Array(10, dtype=int32)
更一般地说,lower
只需要其参数在结构上提供 JAX 为专门化和降级所需的知识。对于像上面这样的典型数组参数,这意味着shape
和dtype
字段。相比之下,对于静态参数,JAX 需要实际的数组值(有关此内容的更多信息,请参见下面)。
使用与其降级不兼容的参数调用 AOT 编译函数会导致错误
>>> x_1d = y_1d = jnp.arange(3)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d)
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with int32[3]
Argument 'y' compiled with int32[] and called with int32[3]
>>> x_f = y_f = jnp.float32(72.)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f)
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with float32[]
Argument 'y' compiled with int32[] and called with float32[]
同样地,AOT 编译函数不能被 JAX 的即时转换(例如jax.jit
,jax.grad()
和jax.vmap()
)进行转换。
使用静态参数进行降级#
使用静态参数进行降级强调了传递给jax.jit
的选项、传递给lower
的参数以及调用结果编译函数所需的参数之间的交互。继续我们上面的示例
>>> lowered_with_x = jax.jit(f, static_argnums=0).lower(7, 8)
>>> # Lowered HLO, specialized to the *value* of the first argument (7)
>>> print(lowered_with_x.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%c = stablehlo.constant dense<14> : tensor<i32>
%0 = stablehlo.add %c, %arg0 : tensor<i32>
return %0 : tensor<i32>
}
}
>>> lowered_with_x.compile()(5)
Array(19, dtype=int32, weak_type=True)
lower
的结果不安全直接序列化以在不同进程中使用。有关此目的的其他 API,请参阅导出和序列化。
请注意,这里的lower
像往常一样接受两个参数,但随后的编译函数只接受剩余的非静态第二个参数。静态第一个参数(值 7)在降级时被视为常量并内置到降级计算中,它可能会与其他常量折叠在一起。在这种情况下,它与 2 相乘被简化为常量 14。
尽管上面的lower
的第二个参数可以用空形状/dtype 结构替换,但静态第一个参数必须是具体值。否则,降级将出错
>>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar)
Traceback (most recent call last):
TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'
>>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5)
Array(25, dtype=int32)
AOT 编译函数不能被转换#
编译函数专门针对特定的一组参数“类型”,例如在我们正在运行的示例中具有特定形状和元素类型的数组。从 JAX 的内部角度来看,诸如jax.vmap()
之类的转换以使编译类型签名失效的方式改变函数的类型签名。作为一项策略,JAX 只是不允许编译函数参与转换。例子
>>> def g(x):
... assert x.shape == (3, 2)
... return x @ jnp.ones(2)
>>> def make_z(*shape):
... return jnp.arange(np.prod(shape)).reshape(shape)
>>> z, zs = make_z(3, 2), make_z(4, 3, 2)
>>> g_jit = jax.jit(g)
>>> g_aot = jax.jit(g).lower(z).compile()
>>> jax.vmap(g_jit)(zs)
Array([[ 1., 5., 9.],
[13., 17., 21.],
[25., 29., 33.],
[37., 41., 45.]], dtype=float32)
>>> jax.vmap(g_aot)(zs)
Traceback (most recent call last):
TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax._src.interpreters.batching.BatchTracer'>
当g_aot
参与自动微分(例如jax.grad()
)时,也会出现类似的错误。为了保持一致性,jax.jit
的转换也被禁止,即使jit
并没有有意义地修改其参数的类型签名。
调试信息和分析(如果有)#
除了主要的 AOT 功能(单独且明确的降级、编译和执行)之外,JAX 的各种 AOT 阶段还提供了一些额外的功能来帮助进行调试和收集编译器反馈。
例如,如上面最初的示例所示,降级函数通常提供文本表示。编译函数也是如此,并且还提供编译器的成本和内存分析。所有这些都是通过jax.stages.Lowered
和jax.stages.Compiled
对象(例如,上面的lowered.as_text()
和compiled.cost_analysis()
)上的方法提供的。
这些方法旨在作为手动检查和调试的辅助工具,而不是可靠的可编程 API。它们的可用性和输出因编译器、平台和运行时而异。这导致了两个重要的警告
如果 JAX 当前后端上某些功能不可用,那么该方法将返回一些琐碎的内容(并且类似于
False
)。例如,如果 JAX 底层的编译器不提供成本分析,那么compiled.cost_analysis()
将为None
。如果某些功能可用,那么仍然对相应方法提供的内容有非常有限的保证。返回值不必在类型、结构或值上在 JAX 配置、后端/平台、版本甚至方法调用之间保持一致。JAX 不能保证
compiled.cost_analysis()
的输出在一天内会在第二天保持不变。
如有疑问,请参阅jax.stages
的包 API 文档。
检查分阶段计算#
本说明顶部的列表中的第 1 阶段提到专门化和分阶段,在降级之前。JAX 对专门针对其参数类型的函数的内部概念并不总是内存中的一种具体化的数据结构。要显式构造对 JAX 在内部Jaxpr 中间语言 中对函数专门化的视图,请参阅jax.make_jaxpr()
。