导出和序列化暂存计算#
提前降低和编译 API 会生成可用于调试或在同一进程中进行编译和执行的对象。有时您可能希望序列化一个降低的 JAX 函数,以便在单独的进程中进行编译和执行,也许是在稍后的时间。这将使您能够
在另一个进程或机器中编译和执行该函数,而无需访问 JAX 程序,也无需重复暂存和降低,例如,在推理系统中。
在一台无法访问您想要稍后编译和执行该函数的加速器的机器上跟踪和降低函数。
归档 JAX 函数的快照,例如,以便稍后能够重现您的结果。 注意:查看此用例的兼容性保证。
有关更多详细信息,请参阅 jax.export
API 参考。
这是一个例子
>>> import re
>>> import numpy as np
>>> import jax
>>> from jax import export
>>> def f(x): return 2 * x * x
>>> exported: export.Exported = export.export(jax.jit(f))(
... jax.ShapeDtypeStruct((), np.float32))
>>> # You can inspect the Exported object
>>> exported.fun_name
'f'
>>> exported.in_avals
(ShapedArray(float32[]),)
>>> print(re.search(r".*@main.*", exported.mlir_module()).group(0))
func.func public @main(%arg0: tensor<f32> loc("x")) -> (tensor<f32> {jax.result_info = ""}) {
>>> # And you can serialize the Exported to a bytearray.
>>> serialized: bytearray = exported.serialize()
>>> # The serialized function can later be rehydrated and called from
>>> # another JAX computation, possibly in another process.
>>> rehydrated_exp: export.Exported = export.deserialize(serialized)
>>> rehydrated_exp.in_avals
(ShapedArray(float32[]),)
>>> def callee(y):
... return 3. * rehydrated_exp.call(y * 4.)
>>> callee(1.)
Array(96., dtype=float32)
序列化分为两个阶段
导出以生成一个
jax.export.Exported
对象,该对象包含降级函数的 StableHLO 以及从另一个 JAX 函数调用它所需的元数据。我们计划添加代码以从 TensorFlow 生成Exported
对象,并使用来自 TensorFlow 和 PyTorch 的Exported
对象。使用 flatbuffers 格式实际序列化为字节数组。有关与 TensorFlow 互操作的另一种序列化为 TensorFlow 图的方法,请参阅 与 TensorFlow 的互操作。
支持反向模式 AD#
序列化可以选择性地支持高阶反向模式 AD。这是通过将原始函数的 jax.vjp()
与原始函数一起序列化来实现的,直到用户指定的阶数(默认为 0,这意味着重新水合的函数不能进行微分)。
>>> import jax
>>> from jax import export
>>> from typing import Callable
>>> def f(x): return 7 * x * x * x
>>> # Serialize 3 levels of VJP along with the primal function
>>> blob: bytearray = export.export(jax.jit(f))(1.).serialize(vjp_order=3)
>>> rehydrated_f: Callable = export.deserialize(blob).call
>>> rehydrated_f(0.1) # 7 * 0.1^3
Array(0.007, dtype=float32)
>>> jax.grad(rehydrated_f)(0.1) # 7*3 * 0.1^2
Array(0.21000001, dtype=float32)
>>> jax.grad(jax.grad(rehydrated_f))(0.1) # 7*3*2 * 0.1
Array(4.2, dtype=float32)
>>> jax.grad(jax.grad(jax.grad(rehydrated_f)))(0.1) # 7*3*2
Array(42., dtype=float32)
>>> jax.grad(jax.grad(jax.grad(jax.grad(rehydrated_f))))(0.1)
Traceback (most recent call last):
ValueError: No VJP is available
请注意,VJP 函数在序列化时以延迟方式计算,此时 JAX 程序仍然可用。这意味着它遵守 JAX VJP 的所有特性,例如 jax.custom_vjp()
和 jax.remat()
。
请注意,重新水合的函数不支持任何其他转换,例如前向模式 AD (jvp) 或 jax.vmap()
。
兼容性保证#
您不应使用仅通过降级获得的原始 StableHLO(jax.jit(f).lower(1.).compiler_ir()
)进行存档和在另一个进程中进行编译,原因有几个。
首先,编译可能使用不同版本的编译器,支持不同版本的 StableHLO。jax.export
模块通过使用 StableHLO 的可移植工件特性来处理 StableHLO opset 可能的演变,从而解决此问题。
自定义调用的兼容性保证#
其次,原始 StableHLO 可能包含引用 C++ 函数的自定义调用。JAX 使用自定义调用来降级少量原语,例如线性代数原语、分片注释或 Pallas 内核。这些不属于 StableHLO 的兼容性保证范围。这些函数的 C++ 实现很少更改,但它们可能会更改。
jax.export
提供以下导出兼容性保证:JAX 导出的工件可以由编译器和 JAX 运行时系统编译和执行,这些编译器和运行时系统
比用于导出的 JAX 版本新 6 个月以内(我们说 JAX 导出提供 6 个月的向后兼容性)。如果我们想要存档导出的工件以便稍后进行编译和执行,这很有用。
比用于导出的 JAX 版本旧 3 周以内(我们说 JAX 导出提供 3 周的向前兼容性)。如果我们想使用在导出之前构建和部署的使用者(例如,在完成导出时已部署的推理系统)来编译和运行导出的工件,这很有用。
(特定的兼容性窗口长度与 JAX 为 jax2tf 承诺的相同,并且基于 TensorFlow 兼容性。“向后兼容性”术语是从使用者(例如,推理系统)的角度来看的。)
重要的是构建导出和使用组件的时间,而不是导出和编译发生的时间。对于外部 JAX 用户,可以在不同的版本上运行 JAX 和 jaxlib;重要的是构建 jaxlib 版本的时间。
为了减少不兼容的可能性,内部 JAX 用户应
尽可能频繁地重建和重新部署使用者系统.
外部用户应
尽可能使用相同版本的 jaxlib 运行导出和使用者系统,并且
使用最新发布的 jaxlib 版本进行存档导出。
如果您绕过 jax.export
API 来获取 StableHLO 代码,则兼容性保证不适用。
为了确保向前兼容性,当我们更改 JAX 降级规则以使用新的自定义调用目标时,JAX 将在 3 周内避免使用新的目标。要使用最新的降级规则,您可以传递 --jax_export_ignore_forward_compatibility=1
配置标志或 JAX_EXPORT_IGNORE_FORWARD_COMPATIBILITY=1
环境变量。
只有一部分自定义调用保证稳定并且具有兼容性保证(请参阅列表)。我们会不断将更多自定义调用目标添加到允许列表中,并进行向后兼容性测试。如果您尝试序列化调用其他自定义调用目标的代码,则会在导出期间收到错误。
如果您想为特定的自定义调用禁用此安全检查,例如,对于目标 my_target
,您可以将 export.DisabledSafetyCheck.custom_call("my_target")
添加到 export
方法的 disabled_checks
参数中,如下例所示
>>> import jax
>>> from jax import export
>>> from jax import lax
>>> from jax._src import core
>>> from jax._src.interpreters import mlir
>>> # Define a new primitive backed by a custom call
>>> new_prim = core.Primitive("new_prim")
>>> _ = new_prim.def_abstract_eval(lambda x: x)
>>> _ = mlir.register_lowering(new_prim, lambda ctx, o: mlir.custom_call("my_new_prim", operands=[o], result_types=[o.type]).results)
>>> print(jax.jit(new_prim.bind).lower(1.).compiler_ir())
module @jit_bind attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32, backend_config = ""} : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
}
>>> # If we try to export, we get an error
>>> export.export(jax.jit(new_prim.bind))(1.)
Traceback (most recent call last):
ValueError: Cannot serialize code with custom calls whose targets have no compatibility guarantees: my_new_bind
>>> # We can avoid the error if we pass a `DisabledSafetyCheck.custom_call`
>>> exp = export.export(
... jax.jit(new_prim.bind),
... disabled_checks=[export.DisabledSafetyCheck.custom_call("my_new_prim")])(1.)
有关确保兼容性的开发人员信息,请参阅 确保向前和向后兼容性。
跨平台和多平台导出#
对于少量 JAX 原语,JAX 降级是特定于平台的。默认情况下,代码会降级并导出到导出计算机上的加速器
>>> from jax import export
>>> export.default_export_platform()
'cpu'
当尝试在没有导出代码的加速器的机器上编译 Exported
对象时,将引发错误的安全性检查。
您可以显式指定应该为哪些平台导出代码。这允许您指定与导出时可用的加速器不同的加速器,甚至允许您指定多平台导出以获取可在多个平台上编译和执行的 Exported
对象。
>>> import jax
>>> from jax import export
>>> from jax import lax
>>> # You can specify the export platform, e.g., `tpu`, `cpu`, `cuda`, `rocm`
>>> # even if the current machine does not have that accelerator.
>>> exp = export.export(jax.jit(lax.cos), platforms=['tpu'])(1.)
>>> # But you will get an error if you try to compile `exp`
>>> # on a machine that does not have TPUs.
>>> exp.call(1.)
Traceback (most recent call last):
ValueError: Function 'cos' was lowered for platforms '('tpu',)' but it is used on '('cpu',)'.
>>> # We can avoid the error if we pass a `DisabledSafetyCheck.platform`
>>> # parameter to `export`, e.g., because you have reasons to believe
>>> # that the code lowered will run adequately on the current
>>> # compilation platform (which is the case for `cos` in this
>>> # example):
>>> exp_unsafe = export.export(jax.jit(lax.cos),
... platforms=['tpu'],
... disabled_checks=[export.DisabledSafetyCheck.platform()])(1.)
>>> exp_unsafe.call(1.)
Array(0.5403023, dtype=float32, weak_type=True)
# and similarly with multi-platform lowering
>>> exp_multi = export.export(jax.jit(lax.cos),
... platforms=['tpu', 'cpu', 'cuda'])(1.)
>>> exp_multi.call(1.)
Array(0.5403023, dtype=float32, weak_type=True)
对于多平台导出,StableHLO 将包含多个降级,但仅针对那些需要它的原语,因此生成的模块大小应仅比具有默认导出的模块的大小略大。作为极端情况,当序列化不包含任何具有平台特定降级的原语的模块时,您将获得与单平台导出相同的 StableHLO。
>>> import jax
>>> from jax import export
>>> from jax import lax
>>> # A largish function
>>> def f(x):
... for i in range(1000):
... x = jnp.cos(x)
... return x
>>> exp_single = export.export(jax.jit(f))(1.)
>>> len(exp_single.mlir_module_serialized)
9220
>>> exp_multi = export.export(jax.jit(f),
... platforms=["cpu", "tpu", "cuda"])(1.)
>>> len(exp_multi.mlir_module_serialized)
9282
形状多态导出#
当在 JIT 模式下使用时,JAX 将为每种输入形状组合单独跟踪和降级函数。导出时,在某些情况下,可以对某些输入维度使用维度变量,以便获得可用于多种输入形状组合的导出工件。
请参阅 形状多态性 文档。
设备多态导出#
导出的工件可能包含输入、输出和某些中间体的分片注释,但这些注释不直接引用导出时存在的实际物理设备。相反,分片注释引用逻辑设备。这意味着您可以在与用于导出的物理设备不同的物理设备上编译和运行导出的工件。
实现设备多态导出的最干净方法是使用使用 jax.sharding.AbstractMesh
构建的分片,其中仅包含网格形状和轴名称。但是,如果您使用为具有具体设备的网格构建的分片,则可以获得相同的结果,因为网格中的实际设备在跟踪和降级时会被忽略
>>> import jax
>>> from jax import export
>>> from jax.sharding import AbstractMesh, Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P
>>>
>>> # Use an AbstractMesh for exporting
>>> export_mesh = AbstractMesh((("a", 4),))
>>> def f(x):
... return x.T
>>> exp = export.export(jax.jit(f))(
... jax.ShapeDtypeStruct((32,), dtype=np.int32,
... sharding=NamedSharding(export_mesh, P("a"))))
>>> # `exp` knows for how many devices it was exported.
>>> exp.nr_devices
4
>>> # and it knows the shardings for the inputs. These will be applied
>>> # when the exported is called.
>>> exp.in_shardings_hlo
({devices=[4]<=[4]},)
>>> # You can also use a concrete set of devices for exporting
>>> concrete_devices = jax.local_devices()[:4]
>>> concrete_mesh = Mesh(concrete_devices, ("a",))
>>> exp2 = export.export(jax.jit(f))(
... jax.ShapeDtypeStruct((32,), dtype=np.int32,
... sharding=NamedSharding(concrete_mesh, P("a"))))
>>> # You can expect the same results
>>> assert exp.in_shardings_hlo == exp2.in_shardings_hlo
>>> # When you call an Exported, you must use a concrete set of devices
>>> arg = jnp.arange(8 * 4)
>>> res1 = exp.call(jax.device_put(arg,
... NamedSharding(concrete_mesh, P("a"))))
>>> # Check out the first 2 shards of the result
>>> [f"device={s.device} index={s.index}" for s in res1.addressable_shards[:2]]
['device=TFRT_CPU_0 index=(slice(0, 8, None),)',
'device=TFRT_CPU_1 index=(slice(8, 16, None),)']
>>> # We can call `exp` with some other 4 devices and another
>>> # mesh with a different shape, as long as the number of devices is
>>> # the same.
>>> other_mesh = Mesh(np.array(jax.local_devices()[2:6]).reshape((2, 2)), ("b", "c"))
>>> res2 = exp.call(jax.device_put(arg,
... NamedSharding(other_mesh, P("b"))))
>>> # Check out the first 2 shards of the result. Notice that the output is
>>> # sharded similarly; this means that the input was resharded according to the
>>> # exp.in_shardings.
>>> [f"device={s.device} index={s.index}" for s in res2.addressable_shards[:2]]
['device=TFRT_CPU_2 index=(slice(0, 8, None),)',
'device=TFRT_CPU_3 index=(slice(8, 16, None),)']
尝试使用与导出时不同的设备数量调用导出的工件是错误的
>>> import jax
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P
>>> export_devices = jax.local_devices()
>>> export_mesh = Mesh(np.array(export_devices), ("a",))
>>> def f(x):
... return x.T
>>> exp = export.export(jax.jit(f))(
... jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
... sharding=NamedSharding(export_mesh, P("a"))))
>>> arg = jnp.arange(4 * len(export_devices))
>>> exp.call(arg)
Traceback (most recent call last):
ValueError: Exported module f was lowered for 8 devices and is called in a context with 1 devices. This is disallowed because: the module was lowered for more than 1 device.
有一些辅助函数可以使用在调用站点构建的新网格来分片用于调用导出工件的输入
>>> import jax
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P
>>> export_devices = jax.local_devices()
>>> export_mesh = Mesh(np.array(export_devices), ("a",))
>>> def f(x):
... return x.T
>>> exp = export.export(jax.jit(f))(
... jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
... sharding=NamedSharding(export_mesh, P("a"))))
>>> # Prepare the mesh for calling `exp`.
>>> calling_mesh = Mesh(np.array(export_devices[::-1]), ("b",))
>>> # Shard the arg according to what `exp` expects.
>>> arg = jnp.arange(4 * len(export_devices))
>>> sharded_arg = jax.device_put(arg, exp.in_shardings_jax(calling_mesh)[0])
>>> res = exp.call(sharded_arg)
作为一种特殊功能,如果为 1 个设备导出了一个函数,并且它不包含任何分片注释,则可以在多个设备上分片的相同形状的参数上调用该函数,并且编译器将适当地分片该函数
```python
>>> import jax
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P
>>> def f(x):
... return jnp.cos(x)
>>> arg = jnp.arange(4)
>>> exp = export.export(jax.jit(f))(arg)
>>> exp.in_avals
(ShapedArray(int32[4]),)
>>> exp.nr_devices
1
>>> # Prepare the mesh for calling `exp`.
>>> calling_mesh = Mesh(jax.local_devices()[:4], ("b",))
>>> # Shard the arg according to what `exp` expects.
>>> sharded_arg = jax.device_put(arg,
... NamedSharding(calling_mesh, P("b")))
>>> res = exp.call(sharded_arg)
调用约定版本#
JAX 导出支持随着时间的推移而发展,例如,为了支持效果。为了支持兼容性(请参阅 兼容性保证),我们为每个 Exported
维护一个调用约定版本。截至 2024 年 6 月,所有使用版本 9(最新版本,请参阅 所有调用约定版本)导出的函数
>>> from jax import export
>>> exp: export.Exported = export.export(jnp.cos)(1.)
>>> exp.calling_convention_version
9
在任何给定时间,导出 API 可能支持一系列调用约定版本。您可以使用 --jax_export_calling_convention_version
标志或 JAX_EXPORT_CALLING_CONVENTION_VERSION
环境变量来控制使用哪个调用约定版本
>>> from jax import export
>>> (export.minimum_supported_calling_convention_version, export.maximum_supported_calling_convention_version)
(9, 9)
>>> from jax._src import config
>>> with config.jax_export_calling_convention_version(9):
... exp = export.export(jnp.cos)(1.)
... exp.calling_convention_version
9
我们保留删除对生成或使用旧于 6 个月的调用约定版本支持的权利。
模块调用约定#
Exported.mlir_module
包含一个 main
函数,如果模块支持多个平台(len(platforms) > 1
),该函数会接受一个可选的第一个平台索引参数,后跟对应于有序效果的令牌参数,以及保留的数组参数(对应于 module_kept_var_idx
和 in_avals
)。平台索引是一个 i32 或 i64 标量,它将当前编译平台的索引编码到 platforms
序列中。
内部函数使用不同的调用约定:一个可选的平台索引参数,可选的维度变量参数(类型为 i32 或 i64 的标量张量),后跟可选的令牌参数(在存在有序效果的情况下),以及常规的数组参数。维度参数对应于 args_avals
中出现的维度变量,按照其名称的排序顺序排列。
考虑一个函数的降低,该函数具有一个类型为 f32[w, 2 * h]
的数组参数,其中 w
和 h
是两个维度变量。假设我们使用多平台降低,并且我们有一个有序的效果。main
函数将如下所示:
func public main(
platform_index: i32 {jax.global_constant="_platform_index"},
token_in: token,
arg: f32[?, ?]) {
arg_w = hlo.get_dimension_size(arg, 0)
dim1 = hlo.get_dimension_size(arg, 1)
arg_h = hlo.floordiv(dim1, 2)
call _check_shape_assertions(arg) # See below
token = new_token()
token_out, res = call _wrapped_jax_export_main(platform_index,
arg_h,
arg_w,
token_in,
arg)
return token_out, res
}
实际计算在 _wrapped_jax_export_main
中,该函数还获取 h
和 w
维度变量的值。
_wrapped_jax_export_main
的签名是:
func private _wrapped_jax_export_main(
platform_index: i32 {jax.global_constant="_platform_index"},
arg_h: i32 {jax.global_constant="h"},
arg_w: i32 {jax.global_constant="w"},
arg_token: stablehlo.token {jax.token=True},
arg: f32[?, ?]) -> (stablehlo.token, ...)
在调用约定版本 9 之前,效果的调用约定是不同的:main
函数不接收或返回令牌。相反,该函数会创建类型为 i1[0]
的虚拟令牌,并将它们传递给 _wrapped_jax_export_main
。_wrapped_jax_export_main
接收类型为 i1[0]
的虚拟令牌,并将内部创建真正的令牌以传递给内部函数。内部函数使用真正的令牌(在调用约定版本 9 之前和之后)。
此外,从调用约定版本 9 开始,包含平台索引或维度变量值的函数参数具有 jax.global_constant
字符串属性,其值是全局常量的名称,可以是 _platform_index
或维度变量名称。如果未知,全局常量名称可以为空。一些全局常量计算使用内部函数,例如 floor_divide
。此类函数的参数对所有属性都具有 jax.global_constant
属性,这意味着函数的结果也是一个全局常量。
请注意,main
包含对 _check_shape_assertions
的调用。JAX 追踪假定 arg.shape[1]
是偶数,并且 w
和 h
的值都 >= 1。当调用模块时,我们必须检查这些约束。我们使用一个特殊的自定义调用 @shape_assertion
,它接受一个布尔类型的第一个操作数,一个字符串 error_message
属性,该属性可能包含格式说明符 {0}
,{1}
,...,以及对应于格式说明符的可变数量的整数标量操作数。
func private _check_shape_assertions(arg: f32[?, ?]) {
# Check that w is >= 1
arg_w = hlo.get_dimension_size(arg, 0)
custom_call @shape_assertion(arg_w >= 1, arg_w,
error_message="Dimension variable 'w' must have integer value >= 1. Found {0}")
# Check that dim1 is even
dim1 = hlo.get_dimension_size(arg, 1)
custom_call @shape_assertion(dim1 % 2 == 0, dim1 % 2,
error_message="Division had remainder {0} when computing the value of 'h')
# Check that h >= 1
arg_h = hlo.floordiv(dim1, 2)
custom_call @shape_assertion(arg_h >= 1, arg_h,
error_message=""Dimension variable 'h' must have integer value >= 1. Found {0}")
调用约定版本#
我们在此列出调用约定版本号的历史记录
版本 1 使用 MHLO & CHLO 来序列化代码,不再支持。
版本 2 支持 StableHLO & CHLO。从 2022 年 10 月开始使用。不再支持。
版本 3 支持平台检查和多平台。从 2023 年 2 月开始使用。不再支持。
版本 4 支持具有兼容性保证的 StableHLO。这是 JAX 本地序列化启动时最早的版本。在 JAX 中从 2023 年 3 月 15 日开始使用 (cl/516885716)。从 2023 年 3 月 28 日开始,我们停止使用
dim_args_spec
(cl/520033493)。对该版本的支持于 2023 年 10 月 17 日被删除 (cl/573858283)。版本 5 增加了对
call_tf_graph
的支持。这目前用于某些特殊用途。在 JAX 中从 2023 年 5 月 3 日开始使用 (cl/529106145)。版本 6 增加了对
disabled_checks
属性的支持。此版本要求platforms
属性不为空。自 2023 年 6 月 7 日起由 XlaCallModule 支持,自 2023 年 6 月 13 日起在 JAX 中可用 (JAX 0.4.13)。版本 7 增加了对
stablehlo.shape_assertion
操作以及在disabled_checks
中指定的shape_assertions
的支持。请参阅 存在形状多态性时的错误。自 2023 年 7 月 12 日起由 XlaCallModule 支持 (cl/547482522),自 2023 年 7 月 20 日起在 JAX 序列化中可用 (JAX 0.4.14),并且自 2023 年 8 月 12 日起为默认 (JAX 0.4.15)。版本 8 增加了对
jax.uses_shape_polymorphism
模块属性的支持,并且仅当该属性存在时才启用形状细化传递。自 2023 年 7 月 21 日起由 XlaCallModule 支持 (cl/549973693),自 2023 年 7 月 26 日起在 JAX 中可用 (JAX 0.4.14),并且自 2023 年 10 月 21 日起为默认 (JAX 0.4.20)。版本 9 增加了对效果的支持。有关精确的调用约定,请参阅
export.Exported
的文档字符串。在此调用约定版本中,我们还使用jax.global_constant
属性标记平台索引和维度变量参数。自 2023 年 10 月 27 日起由 XlaCallModule 支持,自 2023 年 10 月 20 日起在 JAX 中可用 (JAX 0.4.20),并且自 2024 年 2 月 1 日起为默认 (JAX 0.4.24)。截至 2024 年 3 月 27 日,这是唯一受支持的版本。
开发者文档#
调试#
您可以记录导出的模块,在 OSS 和 Google 中使用略有不同的标志。在 OSS 中,您可以执行以下操作:
# Log from python
python tests/export_test.py JaxExportTest.test_basic -v=3
# Or, log from pytest to /tmp/mylog.txt
pytest tests/export_test.py -k test_basic --log-level=3 --log-file=/tmp/mylog.txt
您将看到如下形式的日志行:
I0619 10:54:18.978733 8299482112 _export.py:606] Exported JAX function: fun_name=sin version=9 lowering_platforms=('cpu',) disabled_checks=()
I0619 10:54:18.978767 8299482112 _export.py:607] Define JAX_DUMP_IR_TO to dump the module.
如果您将环境变量 JAX_DUMP_IR_TO
设置为目录,则导出的(以及 JIT 编译的)HLO 模块将保存在那里。
JAX_DUMP_IR_TO=/tmp/export.dumps pytest tests/export_test.py -k test_basic --log-level=3 --log-file=/tmp/mylog.txt
INFO absl:_export.py:606 Exported JAX function: fun_name=sin version=9 lowering_platforms=('cpu',) disabled_checks=()
INFO absl:_export.py:607 The module was dumped to jax_ir0_jit_sin_export.mlir.
您将看到导出的模块(名为 ..._export.mlir
)和 JIT 编译的模块(名为 ..._compile.mlir
)
$ ls -l /tmp/export.dumps/
total 32
-rw-rw-r--@ 1 necula wheel 2316 Jun 19 11:04 jax_ir0_jit_sin_export.mlir
-rw-rw-r--@ 1 necula wheel 2279 Jun 19 11:04 jax_ir1_jit_sin_compile.mlir
-rw-rw-r--@ 1 necula wheel 3377 Jun 19 11:04 jax_ir2_jit_call_exported_compile.mlir
-rw-rw-r--@ 1 necula wheel 2333 Jun 19 11:04 jax_ir3_jit_my_fun_export.mlir
在 Google 内部,您可以使用 --vmodule
参数来指定不同模块的日志记录级别来启用日志记录,例如 --vmodule=_export=3
。
确保向前和向后兼容性#
本节讨论 JAX 开发人员应使用的流程,以确保 兼容性保证。
一个复杂之处在于,外部用户将 JAX 和 jaxlib 安装在单独的包中,并且用户经常最终使用比 JAX 更旧的 jaxlib。我们观察到自定义调用存在于 jaxlib 中,并且只有 jaxlib 与导出工件的消费者相关。为了简化该过程,我们为外部用户设置了期望,即兼容性窗口是根据 jaxlib 版本定义的,并且他们有责任确保即使 JAX 可以使用旧版本,他们也使用新的 jaxlib 进行导出。
因此,我们只关心 jaxlib 版本。当我们发布 jaxlib 版本时,即使我们不强制将其作为允许的最低版本,我们也可以开始向后兼容性弃用时钟。
假设我们需要添加、删除或更改 JAX 降低规则使用的自定义调用目标 T
的语义。以下是可能的时间表(用于更改存在于 jaxlib 中的自定义调用目标):
“D - 1” 天,在更改之前。假设活动的内部 JAX 版本是
0.4.31
(下一个 JAX 和 jaxlib 版本的版本)。JAX 降低规则使用自定义调用T
。“D” 天,我们添加新的自定义调用目标
T_NEW
。我们应该创建一个新的自定义调用目标,并在大约 6 个月后清理旧目标,而不是就地更新T
。请参阅 PR #20997,其中实现了以下步骤的示例。
我们添加自定义调用目标
T_NEW
。我们更改之前使用
T
的 JAX 降低规则,以有条件地使用T_NEW
,如下所示:
from jax._src import config from jax._src.lib import version as jaxlib_version def my_lowering_rule(ctx: LoweringRuleContext, ...): if ctx.is_forward_compat() or jaxlib_version < (0, 4, 31): # this is the old lowering, using target T, while we # are in forward compatibility mode for T, or we # are in OSS and are using an old jaxlib. return hlo.custom_call("T", ...) else: # This is the new lowering, using target T_NEW, for # when we use a jaxlib with version `>= (0, 4, 31)` # (or when this is internal usage), and also we are # in JIT mode. return hlo.custom_call("T_NEW", ...)
请注意,在 JIT 模式下,或者用户传递
--jax_export_ignore_forward_compatibility=true
时,前向兼容模式始终为 false。我们在
_export.py
中将T_NEW
添加到_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE
列表中。
“D + 21” 天(前向兼容窗口结束;可能晚于 21 天):我们移除降级代码中的
forward_compat_mode
,因此只要我们使用新的jaxlib
,导出操作将开始使用新的自定义调用目标T_NEW
。我们为
T_NEW
添加向后兼容性测试。
“RELEASE > D” 天(
D
之后第一个 JAX 发布日期,当我们发布版本0.4.31
时):我们开始计算 6 个月的向后兼容性时间。请注意,这仅在T
是我们已经保证稳定性的自定义调用目标之一时才相关,即在_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE
中列出的目标。如果
RELEASE
位于前向兼容窗口[D, D + 21]
中,并且我们将RELEASE
设置为允许的最低 jaxlib 版本,那么我们可以移除 JIT 分支中的jaxlib_version < (0, 4, 31)
条件。
“RELEASE + 180” 天(向后兼容窗口结束,可能晚于 180 天):到目前为止,我们必须已经提升了最低 jaxlib 版本,以便已经移除了降级条件
jaxlib_version < (0, 4, 31)
,并且 JAX 降级不能生成对T
的自定义调用。我们移除旧的自定义调用目标
T
的 C++ 实现。我们还移除了对
T
的向后兼容性测试。
从 jax.experimental.export 的迁移指南#
在 2024 年 6 月 18 日(JAX 版本 0.4.30),我们弃用了 jax.experimental.export
API,转而使用 jax.export
API。有一些小的更改
jax.experimental.export.export
:旧函数过去允许任何 Python 可调用对象或
jax.jit
的结果。现在只接受后者。你必须在调用export
之前手动将jax.jit
应用于要导出的函数。旧的
lowering_parameters
kwarg 现在命名为platforms
jax.experimental.export.default_lowering_platform()
现在位于jax.export.default_export_platform()
。jax.experimental.export.call
现在是jax.export.Exported
对象的方法。你应该使用exp.call
,而不是export.call(exp)
。jax.experimental.export.serialize
现在是jax.export.Exported
对象的方法。你应该使用exp.serialize()
,而不是export.serialize(exp)
。配置标志
--jax-serialization-version
已弃用。请使用--jax-export-calling-convention-version
。值
jax.experimental.export.minimum_supported_serialization_version
现在位于jax.export.minimum_supported_calling_convention_version
。jax.export.Exported
的以下字段已重命名uses_shape_polymorphism
现在是uses_global_constants
mlir_module_serialization_version
现在是calling_convention_version
lowering_platforms
现在是platforms
。