Jax 和 Jaxlib 版本控制#
为什么 jax
和 jaxlib
是独立的包?#
我们将 JAX 发布为两个独立的 Python wheel 包,分别是 jax
(纯 Python wheel 包) 和 jaxlib
(主要由 C++ 组成的 wheel 包,其中包含如下库):
XLA,
XLA 使用的 LLVM 部分,
MLIR 基础设施,例如 StableHLO Python 绑定。
用于快速 JIT 和 PyTree 操作的 JAX 特定的 C++ 库。
我们发布单独的 jax
和 jaxlib
包,是因为这使得在无需构建 C++ 代码甚至无需安装 C++ 工具链的情况下,更容易处理 JAX 的 Python 部分。jaxlib
是一个大型库,对于许多用户来说不容易构建,但 JAX 的大多数更改只涉及 Python 代码。通过允许 Python 部分独立于 C++ 部分进行更新,我们提高了 Python 更改的开发速度。
此外,构建 jaxlib
的成本很高,但我们希望能够在 CPU 资源不足的环境中迭代并运行 JAX 测试,例如在 Github Actions 或笔记本电脑上。我们的许多 CI 构建只是使用预构建的 jaxlib
,而无需在每个 PR 上重建 JAX 的 C++ 部分。
正如我们将看到的,分别分发 jax
和 jaxlib
会带来一定的成本,因为它要求对 jaxlib
的更改必须保持向后兼容的 API。但是,我们认为,即使以牺牲稍微增加 C++ 更改的难度为代价,使 Python 更改变得容易也是更好的选择。
jax
和 jaxlib
的版本是如何管理的?#
总结:jax
和 jaxlib
在 JAX 源代码树中共享相同的版本号,但作为单独的 Python 包发布。安装时,jax
包的版本必须大于或等于 jaxlib
的版本,并且 jaxlib
的版本必须大于或等于 jax
指定的最小 jaxlib
版本。
jax
和 jaxlib
的发布版本都编号为 x.y.z
,其中 x
是主版本,y
是次版本,z
是可选的补丁版本。版本号必须遵循 PEP 440。版本号比较是对整数元组的字典顺序比较。
每个 jax
发布版本都有一个关联的最小 jaxlib
版本 mx.my.mz
。jax
版本 x.y.z
的最小 jaxlib
版本必须不大于 x.y.z
。
为了使 jax
版本 x.y.z
和 jaxlib
版本 lx.ly.lz
兼容,必须满足以下条件:
jaxlib 版本 (
lx.ly.lz
) 必须大于或等于最小 jaxlib 版本 (mx.my.mz
)。jax 版本 (
x.y.z
) 必须大于或等于 jaxlib 版本 (lx.ly.lz
)。
这些约束意味着以下发布规则:
jax
可以随时单独发布,无需更新jaxlib
。如果发布新的
jaxlib
,则必须同时进行jax
发布。
这些版本约束目前由 jax
在导入时进行检查,而不是表示为 Python 包版本约束。jax
在运行时检查 jaxlib
版本,而不是使用 pip
包版本约束,因为我们为各种硬件和软件版本(例如 GPU、TPU 等)提供单独的 jaxlib
wheel 包。由于我们不知道哪个对任何给定用户来说是正确的选择,因此我们不希望 pip
自动为我们安装 jaxlib
包。
将来,我们希望将 jaxlib
中与硬件相关的部分分离到单独的插件中,届时最小版本可以表示为 Python 包依赖项。现在,我们确实提供了平台特定的额外要求来安装兼容的 jaxlib 版本,例如 jax[cuda]
。
如何安全地修改 jaxlib
的 API?#
jax
可以随时放弃与旧版本jaxlib
的兼容性,只要将最小jaxlib
版本增加到兼容版本即可。但是,请注意,即使对于未发布的jax
版本,最小的jaxlib
也必须是已发布的版本!这允许我们在 CI 构建中使用已发布的jaxlib
wheel 包,并允许 Python 开发人员在无需构建jaxlib
的情况下在 HEAD 上进行jax
的开发。例如,要删除
jax
Python 代码中的旧的向后兼容性路径,只需提高最小 jaxlib 版本,然后删除兼容性路径即可。jaxlib
可以放弃与低于其自身发布版本号的旧版本jax
的兼容性。jax
强制执行的版本约束将禁止使用不兼容的jaxlib
。例如,为了让
jaxlib
删除旧版本jax
使用的 Python 绑定 API,必须增加jaxlib
的次要或主要版本号。如果可能,对
jaxlib
的更改应以向后兼容的方式进行。一般来说,
jaxlib
可以自由地更改其 API,只要遵守关于jax
与所有至少与最小版本一样新的jaxlib
兼容的规则即可。这意味着jax
必须始终与至少两个版本的jaxlib
兼容,即上一个发布版本和树尖版本(实际上是下一个发布版本)。如果保持兼容性,则更容易做到这一点,尽管可以使用jax
中的版本测试进行不兼容的更改;请参见下文。例如,在
jaxlib
中添加新函数通常是安全的,但如果当前jax
仍在 使用它,则删除现有函数或更改其签名是不安全的。对jax
的更改必须适用于所有大于最小版本到 HEAD 的jaxlib
版本,并且可以正常工作或优雅地降级。
请注意,此处的兼容性规则仅适用于 已发布 的 jax
和 jaxlib
版本。它们不适用于未发布的版本;也就是说,如果从未发布或者没有已发布的 jax
版本使用该 API,则可以从 jaxlib
中引入然后再删除 API。
jaxlib
的源代码是如何组织的?#
jaxlib
分布在两个主要存储库中,即主 JAX 存储库中的 jaxlib/
子目录 和 XLA 存储库中的 XLA 源代码树。XLA 中特定于 JAX 的部分主要位于 xla/python
子目录中。
JAX 的 C++ 部分(例如 Python 绑定和运行时组件)位于 XLA 树内的原因部分是历史原因,部分是技术原因。
从历史原因来看,最初 xla/python
绑定被设想为通用的 Python 绑定,可能会与其他框架共享。但在实践中,这种情况越来越少,并且 xla/python
包含许多特定于 JAX 的部分,而且可能会包含更多。因此,最好直接将 xla/python
视为 JAX 的一部分。
从技术原因来看,XLA C++ API 并不稳定。通过将 XLA:Python 绑定保留在 XLA 树中,它们的 C++ 实现可以与 XLA 的 C++ API 原子地更新。维护 Python API 的向后和向前兼容性比 C++ API 更容易,因此 xla/python
公开 Python API,并负责在 Python 级别维护向后兼容性。
jaxlib
是使用 Bazel 从 jax
仓库构建的。jaxlib
中来自 XLA 仓库的部分作为 Bazel 子模块被合并到构建过程中。要更新构建期间使用的 XLA 版本,必须更新 Bazel WORKSPACE
中固定的版本。这通常根据需要手动完成,但可以在每次构建时被覆盖。
在发布之间,我们如何在 jax
和 jaxlib
边界之间进行更改?#
jaxlib 版本是一个粗略的工具:它只允许我们推断发布版本。
然而,由于 jax
和 jaxlib
代码分布在不能在单个更改中原子更新的仓库中,我们需要以比我们的发布周期更精细的粒度来管理兼容性。为了管理细粒度的兼容性,我们有独立于 jaxlib
发布版本号的其他版本控制。
我们在 XLA 仓库中的 xla_client.py
中维护一个额外的版本号(_version
)。其想法是,这个版本号在 xla/python
中与 JAX 的 C++ 部分一起定义,也可以作为 jax._src.lib.xla_extension_version
被 JAX Python 访问,并且每次对 XLA/Python 代码进行更改时,如果对 jax
的向后兼容性有影响,则必须递增此版本号。然后,JAX Python 代码可以使用此版本号来维护向后兼容性,例如:
from jax._src.lib import xla_extension_version
# 123 is the new version number for _version in xla_client.py
if xla_extension_version >= 123:
# Use new code path
...
else:
# Use old code path.
请注意,此版本号是对已发布版本号的约束的补充,也就是说,此版本号的存在是为了在未发布代码的开发过程中帮助管理兼容性。发布版本还必须遵循上述兼容性规则。