Jax 和 Jaxlib 版本控制#

为什么 jaxjaxlib 是单独的包?#

我们将 JAX 发布为两个独立的 Python wheel 包,即 jax(一个纯 Python wheel 包)和 jaxlib(一个主要由 C++ 编写的 wheel 包,包含诸如以下库):

  • XLA,

  • XLA 使用的 LLVM 部分,

  • MLIR 基础设施,例如 StableHLO Python 绑定。

  • 用于快速 JIT 和 PyTree 操作的 JAX 特定 C++ 库。

我们分发独立的 jaxjaxlib 包,因为它使得在无需构建 C++ 代码甚至无需安装 C++ 工具链的情况下,轻松处理 JAX 的 Python 部分成为可能。jaxlib 是一个大型库,对许多用户来说构建起来并不容易,但 JAX 的大多数更改都只涉及 Python 代码。通过允许 Python 部分独立于 C++ 部分进行更新,我们提高了 Python 更改的开发速度。

此外,jaxlib 的构建成本很高,但我们希望能够在没有大量 CPU 的环境中迭代和运行 JAX 测试,例如在 Github Actions 或笔记本电脑上。我们的许多 CI 构建只是使用预构建的 jaxlib,而不是在每个 PR 上都需要重建 JAX 的 C++ 部分。

正如我们将看到的,分别分发 jaxjaxlib 会带来一定的代价,因为它要求对 jaxlib 的更改保持向后兼容的 API。然而,我们相信,即使以使 C++ 更改稍微困难为代价,使 Python 更改变得容易也是更可取的。

jaxjaxlib 如何进行版本控制?#

总结:jaxjaxlib 在 JAX 源代码树中共享相同的版本号,但作为单独的 Python 包发布。安装时,jax 包的版本必须大于或等于 jaxlib 的版本,并且 jaxlib 的版本必须大于或等于 jax 指定的最小 jaxlib 版本。

jaxjaxlib 的版本号均为 x.y.z,其中 x 是主版本号,y 是次版本号,z 是可选的补丁版本号。版本号必须遵循 PEP 440。版本号比较是对整数元组的字典顺序比较。

每个 jax 版本都有一个关联的最小 jaxlib 版本 mx.my.mzjax 版本 x.y.z 的最小 jaxlib 版本必须不大于 x.y.z

为了使 jax 版本 x.y.zjaxlib 版本 lx.ly.lz 兼容,必须满足以下条件:

  • jaxlib 版本(lx.ly.lz)必须大于或等于最小 jaxlib 版本(mx.my.mz)。

  • jax 版本(x.y.z)必须大于或等于 jaxlib 版本(lx.ly.lz)。

这些约束意味着以下发布规则:

  • 可以随时单独发布 jax,而无需更新 jaxlib

  • 如果发布新的 jaxlib,则必须同时发布一个 jax 版本。

这些版本约束目前在导入时由 jax 检查,而不是表示为 Python 包版本约束。jax 在运行时检查 jaxlib 版本,而不是使用 pip 包版本约束,因为我们为各种硬件和软件版本(例如,GPU、TPU 等)提供单独的 jaxlib wheel 包。由于我们不知道哪个是任何给定用户的正确选择,因此我们不希望 pip 自动为我们安装 jaxlib 包。

将来,我们希望将 jaxlib 中特定于硬件的部分分离到单独的插件中,到那时,最小版本可以表示为 Python 包依赖项。目前,我们确实提供了特定于平台的额外要求,这些要求会安装兼容的 jaxlib 版本,例如,jax[cuda]

如何安全地更改 jaxlib 的 API?#

  • jax 可以随时放弃与旧版本 jaxlib 的兼容性,只要将最小 jaxlib 版本增加到兼容版本即可。但是,请注意,即使对于未发布的 jax 版本,最小 jaxlib 也必须是已发布的版本!这允许我们在 CI 构建中使用已发布的 jaxlib wheel 包,并允许 Python 开发人员在 HEAD 上处理 jax,而无需构建 jaxlib

    例如,要删除 jax Python 代码中的旧向后兼容路径,只需提高最小 jaxlib 版本,然后删除兼容路径即可。

  • jaxlib 可以放弃与低于其自身发布版本号的旧版本 jax 的兼容性。 jax 强制执行的版本约束将禁止使用不兼容的 jaxlib

    例如,为了使 jaxlib 删除旧版本 jax 使用的 Python 绑定 API,必须增加 jaxlib 的次版本号或主版本号。

  • 如果可能,应以向后兼容的方式对 jaxlib 进行更改。

    一般来说,jaxlib 可以自由更改其 API,只要遵循 jax 与所有至少与最小版本一样新的 jaxlib 兼容的规则即可。这意味着 jax 必须始终与至少两个版本的 jaxlib 兼容,即,上次发布版本和树顶版本,实际上是下一个发布版本。如果保持兼容性,则更容易做到这一点,尽管可以使用 jax 的版本测试进行不兼容的更改;请参见下文。

    例如,通常在 jaxlib 中添加新函数是安全的,但如果当前的 jax 仍在调用现有函数,则删除现有函数或更改其签名是不安全的。jax 的更改必须适用于从最小版本到 HEAD 的所有 jaxlib 版本,或者正常降级。

请注意,此处的兼容性规则仅适用于 jaxjaxlib 的 *已发布* 版本。它们不适用于未发布的版本;也就是说,如果从未发布 jaxlib 的 API,或者如果没有已发布的 jax 版本使用该 API,则可以从 jaxlib 中引入然后删除 API。

jaxlib 的源代码是如何布局的?#

jaxlib 分布在两个主要存储库中,即主 JAX 存储库中的 jaxlib/ 子目录XLA 源代码树(位于 XLA 存储库内)。XLA 中特定于 JAX 的部分主要位于 xla/python 子目录中。

JAX 的 C++ 部分(例如 Python 绑定和运行时组件)位于 XLA 树中的原因是部分历史原因和部分技术原因。

历史原因是,最初设想 xla/python 绑定是通用的 Python 绑定,可以与其他框架共享。但在实践中,这种情况越来越少,xla/python 包含了一些特定于 JAX 的部分,并且可能会包含更多。因此,最好直接将 xla/python 视为 JAX 的一部分。

技术原因是 XLA C++ API 不稳定。通过将 XLA:Python 绑定保留在 XLA 树中,它们的 C++ 实现可以与 XLA 的 C++ API 原子地更新。维护 Python API 的向后和向前兼容性比 C++ API 更容易,因此 xla/python 公开 Python API,并负责在 Python 级别维护向后兼容性。

jaxlib 是使用 Bazel 从 jax 仓库构建的。jaxlib 中来自 XLA 仓库的部分作为 Bazel 子模块被纳入构建过程中。要更新构建期间使用的 XLA 版本,必须更新 Bazel WORKSPACE 中固定的版本。这是根据需要手动完成的,但可以在每次构建的基础上进行覆盖。

在发布之间,我们如何在 jaxjaxlib 边界之间进行更改?#

jaxlib 版本是一个粗略的工具:它只允许我们推理发布

但是,由于 jaxjaxlib 代码分布在不同的仓库中,无法在单个更改中原子地更新,我们需要以比发布周期更精细的粒度来管理兼容性。为了管理细粒度的兼容性,我们有独立于 jaxlib 发布版本号的额外版本控制。

我们在 XLA 仓库的 xla_client.py 中维护一个额外的版本号 (_version)。想法是,这个版本号在 xla/python 中与 JAX 的 C++ 部分一起定义,并且 JAX Python 也可以通过 jax._src.lib.xla_extension_version 访问,并且每次对 XLA/Python 代码进行更改,而这些更改对 jax 具有向后兼容性影响时,都必须递增。然后,JAX Python 代码可以使用此版本号来维护向后兼容性,例如:

from jax._src.lib import xla_extension_version

# 123 is the new version number for _version in xla_client.py
if xla_extension_version >= 123:
  # Use new code path
  ...
else:
  # Use old code path.

请注意,此版本号是除了已发布版本号的约束之外的,也就是说,此版本号的存在是为了帮助管理未发布代码的开发期间的兼容性。发布也必须遵循上面给出的兼容性规则。