Jax 和 Jaxlib 版本控制#

为什么 jaxjaxlib 是单独的包?#

我们将 JAX 发布为两个单独的 Python 轮子,即 jax,这是一个纯 Python 轮子,以及 jaxlib,这是一个主要由 C++ 构成的轮子,其中包含以下库:

  • XLA,

  • XLA 使用的 LLVM 部分,

  • MLIR 基础设施,例如 StableHLO Python 绑定。

  • 用于快速 JIT 和 PyTree 操作的 JAX 特定 C++ 库。

我们发布单独的 jaxjaxlib 包,因为这样可以轻松地处理 JAX 的 Python 部分,而无需构建 C++ 代码,甚至无需安装 C++ 工具链。 jaxlib 是一个大型库,许多用户难以构建,但 JAX 的大多数更改仅涉及 Python 代码。通过允许 Python 部分独立于 C++ 部分更新,我们提高了 Python 更改的开发速度。

此外,jaxlib 的构建成本很高,但我们希望能够在 CPU 资源有限的环境中迭代和运行 JAX 测试,例如在 Github Actions 中或在笔记本电脑上。我们的许多 CI 构建只需使用预构建的 jaxlib,而不是需要在每个 PR 上重新构建 JAX 的 C++ 部分。

正如我们将看到的,分别发布 jaxjaxlib 会带来一定的成本,因为它要求对 jaxlib 的更改保持向后兼容的 API。但是,我们认为,总的来说,最好使 Python 更改变得容易,即使是以使 C++ 更改稍微变得更难为代价。

jaxjaxlib 如何进行版本控制?#

总结:jaxjaxlib 在 JAX 源代码树中共享相同的版本号,但作为单独的 Python 包发布。安装时,jax 包版本必须大于或等于 jaxlib 的版本,并且 jaxlib 的版本必须大于或等于 jax 指定的最小 jaxlib 版本。

jaxjaxlib 的版本号均为 x.y.z 格式,其中 x 为主版本号,y 为次版本号,z 为可选的补丁版本号。版本号必须遵循 PEP 440。版本号比较是整数元组的词典序比较。

每个 jax 版本都关联一个最小的 jaxlib 版本 mx.my.mz。对于 jax 版本 x.y.z,其最小 jaxlib 版本必须不大于 x.y.z

对于 jax 版本 x.y.zjaxlib 版本 lx.ly.lz 兼容,必须满足以下条件:

  • jaxlib 版本 (lx.ly.lz) 必须大于或等于最小 jaxlib 版本 (mx.my.mz)。

  • jax 版本 (x.y.z) 必须大于或等于 jaxlib 版本 (lx.ly.lz)。

这些约束意味着以下发布规则:

  • jax 可以在任何时间独立发布,无需更新 jaxlib

  • 如果发布了新的 jaxlib,则必须同时发布 jax 版本。

这些 版本约束 目前由 jax 在导入时检查,而不是表示为 Python 包版本约束。 jax 在运行时检查 jaxlib 版本,而不是使用 pip 包版本约束,因为我们 为各种硬件和软件版本(例如,GPU、TPU 等)提供单独的 jaxlib wheel 文件。由于我们不知道哪个是任何给定用户的正确选择,因此我们不希望 pip 自动为我们安装 jaxlib 包。

将来,我们希望将 jaxlib 中特定于硬件的部分分离到单独的插件中,此时最小版本可以表示为 Python 包依赖项。目前,我们确实提供了特定于平台的额外需求,这些需求会安装兼容的 jaxlib 版本,例如 jax[cuda]

如何安全地更改 jaxlib 的 API?#

  • jax 可以在任何时间放弃与旧版 jaxlib 版本的兼容性,只要最小 jaxlib 版本增加到兼容版本即可。但是,请注意,即使对于未发布的 jax 版本,最小 jaxlib 版本也必须是已发布的版本!这使我们能够在 CI 构建中使用已发布的 jaxlib wheel 文件,并允许 Python 开发人员在 HEAD 上处理 jax 而无需构建 jaxlib

    例如,要删除 jax Python 代码中的旧版向后兼容路径,只需增加最小 jaxlib 版本,然后删除兼容性路径即可。

  • jaxlib 可以放弃与低于其自身发布版本号的旧版 jax 版本的兼容性。 jax 强制的版本约束将禁止使用不兼容的 jaxlib

    例如,对于 jaxlib 要删除旧版 jax 版本使用的 Python 绑定 API,必须递增 jaxlib 的次版本号或主版本号。

  • 如果可能,对 jaxlib 的更改应以向后兼容的方式进行。

    通常,jaxlib 可以自由更改其 API,只要遵循关于 jax 兼容所有至少与最小版本一样新的 jaxlib 的规则即可。这意味着 jax 必须始终兼容至少两个版本的 jaxlib,即最后一个版本和树顶版本(实际上是下一个版本)。如果保持兼容性,则更容易做到这一点,尽管可以使用来自 jax 的版本测试进行不兼容的更改;请参见下文。

    例如,通常可以安全地向 jaxlib 添加新函数,但如果当前 jax 仍在使用它,则删除现有函数或更改其签名是不安全的。对 jax 的更改必须对所有大于最小版本到 HEAD 的 jaxlib 版本起作用或优雅降级。

请注意,此处的兼容性规则仅适用于 jaxjaxlib已发布版本。它们不适用于未发布的版本;也就是说,如果从未发布某个 API,或者没有已发布的 jax 版本使用该 API,则可以在 jaxlib 中引入并删除该 API。

jaxlib 的源代码如何布局?#

jaxlib 分散在两个主要存储库中,即 jaxlib/ 主 JAX 存储库中的子目录XLA 源代码树(位于 XLA 存储库中)。XLA 中特定于 JAX 的部分主要位于 xla/python 子目录

JAX 的 C++ 部分(例如 Python 绑定和运行时组件)位于 XLA 树中的原因部分是历史原因,部分是技术原因。

历史原因是,最初 xla/python 绑定被设想为通用 Python 绑定,可能与其他框架共享。在实践中,这种情况越来越少,并且 xla/python 集成了许多特定于 JAX 的部分,并且可能会集成更多部分。因此,最好简单地将 xla/python 视为 JAX 的一部分。

技术原因是 XLA C++ API 不稳定。通过将 XLA:Python 绑定保存在 XLA 树中,它们的 C++ 实现可以与 XLA 的 C++ API 原子地更新。维护 Python API 的向后和向前兼容性比维护 C++ API 的兼容性更容易,因此 xla/python 公开了 Python API,并负责在 Python 层面保持向后兼容性。

jaxlib 使用 Bazel 从 jax 存储库构建。来自 XLA 存储库的 jaxlib 部分被合并到构建中 作为 Bazel 子模块。要更新构建期间使用的 XLA 版本,必须更新 Bazel WORKSPACE 中的固定版本。这是根据需要手动完成的,但可以在每次构建的基础上覆盖。

如何在发布之间跨 jaxjaxlib 边界进行更改?#

jaxlib 版本是一种粗略的工具:它只允许我们推断发布

然而,由于jaxjaxlib代码分布在无法在单个更改中原子更新的多个代码库中,因此我们需要以比我们的发布周期更细粒度的级别管理兼容性。为了管理细粒度的兼容性,我们采用了独立于jaxlib版本号的其他版本控制方式。

我们在XLA代码库中的xla_client.py中维护了一个额外的版本号(_version)。这个版本号的理念是,它在xla/python中定义,与JAX的C++部分一起,也能够被JAX Python访问,作为jax._src.lib.xla_extension_version,并且每次对XLA/Python代码进行可能影响jax向后兼容性的更改时,都必须递增。然后,JAX Python代码可以使用此版本号来维护向后兼容性,例如:

from jax._src.lib import xla_extension_version

# 123 is the new version number for _version in xla_client.py
if xla_extension_version >= 123:
  # Use new code path
  ...
else:
  # Use old code path.

请注意,此版本号是**除了**发布版本号的约束之外的,也就是说,此版本号的存在是为了帮助在开发未发布的代码期间管理兼容性。发布版本也必须遵循上述兼容性规则。