从源代码构建#

首先,获取 JAX 源代码

git clone https://github.com/jax-ml/jax
cd jax

构建 JAX 涉及两个步骤

  1. 构建或安装 jaxlib,它是 jax 的 C++ 支持库。

  2. 安装 jax Python 包。

构建或安装 jaxlib#

使用 pip 安装 jaxlib#

如果您只修改 JAX 的 Python 部分,我们建议使用 pip 从预构建的 wheel 安装 jaxlib

pip install jaxlib

有关 pip 安装的完整指南(例如,对于 GPU 和 TPU 支持),请参阅 JAX 自述文件

从源代码构建 jaxlib#

警告

虽然通常可以使用大多数现代编译器从源代码编译 jaxlib,但构建仅使用 clang 进行测试。欢迎提交 Pull Request 以改进对不同工具链的支持,但其他编译器未获得积极支持。

要从源代码构建 jaxlib,您还必须安装一些先决条件

  • C++ 编译器

    如上面的框中所述,最好使用最新版本的 clang(在撰写本文时,我们测试的版本是 18),但其他编译器(例如 g++ 或 MSVC)也可能有效。

    在 Ubuntu 或 Debian 上,您可以按照 LLVM 文档中的说明安装最新稳定版本的 clang。

    如果您在 Mac 上进行构建,请确保已安装 XCode 和 XCode 命令行工具。

    有关 Windows 构建说明,请参阅下文。

  • Python:用于运行构建辅助脚本。请注意,无需在本地安装 Python 依赖项,因为在构建过程中将忽略您的系统 Python;请查看 管理密封 Python 了解详细信息。

要为 CPU 或 TPU 构建 jaxlib,您可以运行

python build/build.py build --wheels=jaxlib --verbose
pip install dist/*.whl  # installs jaxlib (includes XLA)

要为与当前系统安装不同的 Python 版本构建 wheel,请将 --python_version 标志传递给构建命令

python build/build.py build --wheels=jaxlib --python_version=3.12 --verbose

本文档的其余部分假设您正在为与当前系统安装匹配的 Python 版本进行构建。如果您需要为不同的版本进行构建,只需在每次调用 python build/build.py 时附加 --python_version=<py version> 标志。请注意,无论是否传递 --python_version 参数,Bazel 构建都将始终使用密封的 Python 安装。

如果您想构建 jaxlib 和 CUDA 插件:请运行

python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt

生成三个 wheel(不带 cuda 的 jaxlib、jax-cuda-plugin 和 jax-cuda-pjrt)。默认情况下,所有 CUDA 编译步骤都由 NVCC 和 clang 执行,但可以通过 --build_cuda_with_clang 标志将其限制为 clang。

有关配置选项,请参阅 python build/build.py --help。此处的 python 应该是您的 Python 3 解释器的名称;在某些系统上,您可能需要使用 python3 代替。尽管使用 python 调用脚本,Bazel 将始终使用其自己的密封 Python 解释器和依赖项,只有 build/build.py 脚本本身将由您的系统 Python 解释器处理。默认情况下,wheel 将写入当前目录的 dist/ 子目录。

  • 从 v.0.4.32 开始的 JAX 版本:您可以在配置选项中提供自定义 CUDA 和 CUDNN 版本。Bazel 将下载它们并用作目标依赖项。

    要下载特定版本的 CUDA/CUDNN 再分发,您可以使用 --cuda_version--cudnn_version 标志

    python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 \
    --cudnn_version=9.1.1
    

    python build/build.py build --wheels=jax-cuda-pjrt --cuda_version=12.3.2 \
    --cudnn_version=9.1.1
    

    请注意,这些参数是可选的:默认情况下,Bazel 将下载 .bazelrc 中环境变量 HERMETIC_CUDA_VERSIONHERMETIC_CUDNN_VERSION 中提供的 CUDA 和 CUDNN 再分发版本。

    要指向本地文件系统上的 CUDA/CUDNN/NCCL 再分发,您可以使用以下命令

    python build/build.py build --wheels=jax-cuda-plugin \
    --bazel_options=--repo_env=LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" \
    --bazel_options=--repo_env=LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" \
    --bazel_options=--repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl"
    

    请参阅 XLA 文档 中的完整说明列表。

  • v.0.4.32 之前的 JAX 版本:您必须安装 CUDA 和 CUDNN,并使用配置选项提供其路径。

使用修改后的 XLA 存储库从源代码构建 jaxlib。#

JAX 依赖于 XLA,其源代码位于 XLA GitHub 存储库中。默认情况下,JAX 使用 XLA 存储库的固定副本,但在处理 JAX 时,我们通常希望使用本地修改的 XLA 副本。有两种方法可以实现此目的

  • 使用 Bazel 的 override_repository 功能,您可以将其作为命令行标志传递给 build.py,如下所示

    python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla
    
  • 修改 JAX 源代码树根目录中的 WORKSPACE 文件,以指向不同的 XLA 树。

要将更改贡献回 XLA,请将 PR 发送到 XLA 存储库。

JAX 固定的 XLA 版本会定期更新,但在每次 jaxlib 发布之前会特别更新。

在 Windows 上从源代码构建 jaxlib 的其他说明#

注意:JAX 不支持 Windows 上的 CUDA;请使用 WSL2 获取 CUDA 支持。

在 Windows 上,请按照 安装 Visual Studio 设置 C++ 工具链。需要 Visual Studio 2019 版本 16.5 或更高版本。

JAX 构建使用符号链接,这需要您激活开发人员模式

您可以使用其 Windows 安装程序安装 Python,或者如果您愿意,可以使用 AnacondaMiniconda 设置 Python 环境。

Bazel 的某些目标使用 bash 实用程序执行脚本,因此需要 MSYS2。有关更多详细信息,请参阅 在 Windows 上安装 Bazel。安装以下软件包

pacman -S patch coreutils

安装 coreutils 后,realpath 命令应存在于 shell 的路径中。

安装所有内容后。打开 PowerShell,并确保 MSYS2 位于当前会话的路径中。确保 bazelpatchrealpath 可以访问。激活 conda 环境。

python .\build\build.py build --wheels=jaxlib

要使用调试信息进行构建,请添加标志 --bazel_options='--copt=/Z7'

为 AMD GPU 构建 ROCM jaxlib 的其他说明#

有关使用 ROCm 支持构建 jaxlib 的详细说明,请参阅官方指南:从源代码构建 ROCm JAX

管理密封 Python#

为了确保 JAX 的构建是可重现的,在支持的平台(Linux、Windows、MacOS)上行为一致,并且与本地系统的具体情况正确隔离,我们依赖于密封的 Python(由 rules_python 提供,有关详细信息,请参阅 工具链注册),用于通过 Bazel 执行的所有构建和测试命令。这意味着在构建过程中将忽略您的系统 Python 安装,并且 Python 解释器本身以及所有 Python 依赖项将由 bazel 直接管理。

指定 Python 版本#

当您运行 build/build.py 工具时,密封 Python 的版本会自动设置为与您用于运行 build/build.py 脚本的 Python 版本匹配。要显式选择特定版本,您可以将 --python_version 参数传递给该工具

python build/build.py build --python_version=3.12

在底层,密封 Python 版本由 HERMETIC_PYTHON_VERSION 环境变量控制,该变量在您运行 build/build.py 时会自动设置。如果您直接运行 bazel,您可能需要通过以下方式之一显式设置该变量

# Either add an entry to your `.bazelrc` file
build --repo_env=HERMETIC_PYTHON_VERSION=3.12

# OR pass it directly to your specific build command
bazel build <target> --repo_env=HERMETIC_PYTHON_VERSION=3.12

# OR set the environment variable globally in your shell:
export HERMETIC_PYTHON_VERSION=3.12

您可以在同一台机器上按顺序针对不同的 Python 版本运行构建和测试,只需在运行之间切换 --python_version 的值即可。先前构建中所有与 python 无关的构建缓存都将保留并重用于后续构建。

指定 Python 依赖项#

在 bazel 构建期间,所有 JAX 的 Python 依赖项都固定为特定版本。这对于确保构建的可重现性是必要的。JAX 依赖项的完整传递闭包的固定版本及其相应的哈希值在 build/requirements_lock_<python version>.txt 文件中指定(例如,build/requirements_lock_3_12.txt 用于 Python 3.12)。

要更新锁定文件,请确保 build/requirements.in 包含所需的直接依赖项列表,然后执行以下命令(该命令将在后台调用 pip-compile

python build/build.py requirements_update --python_version=3.12

或者,如果您需要更多控制,可以直接运行 bazel 命令(这两个命令是等效的)

bazel run //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION=3.12

其中 3.12 是您希望更新的 Python 版本。

请注意,由于底层仍然使用 pippip-compile 工具,因此这些工具支持的大部分命令行参数和功能也将被 Bazel 需求更新器命令识别。例如,如果您希望更新器考虑预发布版本,只需将 --pre 参数传递给 bazel 命令

bazel run //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION=3.12 -- --pre

指定本地 wheel 文件的依赖项#

默认情况下,构建会扫描仓库根目录下的 dist 目录,以查找要包含在依赖项列表中的任何本地 .whl 文件。如果 wheel 文件是 Python 版本特定的,则只会包含与所选 Python 版本匹配的 wheel 文件。

整体本地 wheel 搜索和选择逻辑由 python_init_repositories() 宏(直接从 WORKSPACE 文件调用)的参数控制。您可以使用 local_wheel_dist_folder 来更改本地 wheel 文件所在的文件夹位置。使用 local_wheel_inclusion_listlocal_wheel_exclusion_list 参数来指定应该包含和/或排除哪些 wheel 文件(它支持基本的通配符匹配)。

如有必要,您还可以手动依赖本地 .whl 文件,从而绕过自动本地 wheel 搜索机制。例如,要依赖您新构建的 jaxlib wheel 文件,您可以在 build/requirements.in 中添加 wheel 文件的路径,并为选定的 Python 版本重新运行需求更新器命令。例如

echo -e "\n$(realpath jaxlib-0.4.27.dev20240416-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in
python build/build.py requirements_update --python_version=3.12

指定 nightly wheel 文件的依赖项#

为了构建和测试最新的、可能不稳定的 Python 依赖项集,我们提供了以下依赖项更新器命令的特殊版本

python build/build.py requirements_update --python_version=3.12 --nightly_update

或者,如果您直接运行 bazel(这两个命令是等效的)

bazel run //build:requirements_nightly.update --repo_env=HERMETIC_PYTHON_VERSION=3.12

此命令与常规更新器的区别在于,默认情况下它会接受预发布、开发和 nightly 包,它还将搜索 https://pypi.anaconda.org/scientific-python-nightly-wheels/simple 作为额外的索引 URL,并且不会在生成的需求锁定文件中放置哈希值。

自定义 hermetic Python (高级用法)#

我们支持所有当前的 Python 版本,因此除非您的工作流程有非常特殊的要求(例如使用您自己的自定义 Python 解释器的能力),否则您可以安全地完全跳过本节。

简而言之,如果您依赖于非标准的 Python 工作流程,您仍然可以在 hermetic Python 设置中实现高度的灵活性。从概念上讲,与非 hermetic 的情况相比,只会有一个区别:您需要从文件的角度思考,而不是安装(即思考您的构建实际依赖哪些文件,而不是需要在您的系统上安装哪些文件),其余的都差不多。

因此,在实践中,要完全控制您的 Python 环境,无论是否是 hermetic 的,您都需要能够执行以下三件事

  1. 指定要使用的 python 解释器(即选择实际的 pythonpython3 二进制文件以及与之位于同一文件夹中的库)。

  2. 指定 Python 依赖项列表(例如 numpy)及其各自的版本。

  3. 能够轻松地在列表中添加/删除/更新依赖项。每个依赖项本身也可以是自定义的(例如,自行构建的)。

您已经知道如何在非 hermetic Python 环境中执行上述所有步骤,以下是在 hermetic 环境中执行相同操作的方法(从文件而不是安装的角度来处理)

  1. 与其安装 Python,不如将 Python 解释器放入 tarzip 文件中。根据您的情况,您可以简单地拉取许多现有的解释器之一(例如 python-build-standalone),或者构建您自己的解释器并将其打包到存档中(遵循官方的 构建说明即可)。例如,在 Linux 上,它看起来像这样

    ./configure --prefix python
    make -j12
    make altinstall
    tar -czpf my_python.tgz python
    

    准备好 tarball 后,通过将 HERMETIC_PYTHON_URL 环境变量指向该存档(本地的或来自互联网的)来将其插入到构建中

    --repo_env=HERMETIC_PYTHON_URL="file:///local/path/to/my_python.tgz"
    --repo_env=HERMETIC_PYTHON_SHA256=<file's_sha256_sum>
    
    # OR
    --repo_env=HERMETIC_PYTHON_URL="https://remote/url/to/my_python.tgz"
    --repo_env=HERMETIC_PYTHON_SHA256=<file's_sha256_sum>
    
    # We assume that top-level folder in the tarbal is called "python", if it is
    # something different just pass additional HERMETIC_PYTHON_PREFIX parameter
    --repo_env=HERMETIC_PYTHON_URL="https://remote/url/to/my_python.tgz"
    --repo_env=HERMETIC_PYTHON_SHA256=<file's_sha256_sum>
    --repo_env=HERMETIC_PYTHON_PREFIX="my_python/install"
    
  2. 与其执行 pip install,不如创建 requirements_lock.txt 文件,其中包含依赖项的完整传递闭包。您还可以依赖此仓库中已有的现有依赖项(只要它们与您的自定义 Python 版本一起使用)。关于如何执行此操作,没有特殊说明,您可以按照本文档中 指定 Python 依赖项 中推荐的步骤进行操作,只需直接调用 pip-compile 即可(请注意,锁定文件必须是 hermetic 的,但如果您愿意,也可以从非 hermetic 的 python 生成它),甚至可以手动创建它(请注意,哈希值在锁定文件中是可选的)。

  3. 如果您需要更新或自定义您的依赖项列表,您可以再次按照 指定 Python 依赖项 说明来更新 requirements_lock.txt,直接调用 pip-compile 或手动修改它。如果您有一个想要使用的自定义包,只需从您的锁定文件中直接指向它的 .whl 文件(请记住,要从文件而不是安装的角度考虑)(请注意,requirements.txtrequirements_lock.txt 文件支持本地 wheel 引用)。如果您的 requirements_lock.txt 已经在 WORKSPACE 文件中被指定为 python_init_repositories() 的依赖项,则您无需执行任何其他操作。否则,您可以按如下方式指向您的自定义文件

    --repo_env=HERMETIC_REQUIREMENTS_LOCK="/absolute/path/to/custom_requirements_lock.txt"
    

    另请注意,如果您使用 HERMETIC_REQUIREMENTS_LOCK,则它将完全控制您的依赖项列表,并且 指定本地 wheel 文件的依赖项 中描述的自动本地 wheel 文件解析逻辑将被禁用,以免干扰它。

就是这样。总结一下:如果您有一个包含 Python 解释器的存档文件和一个包含依赖项的完整传递闭包的 requirements_lock.txt 文件,那么您就可以完全控制您的 Python 环境。

自定义 hermetic Python 示例#

请注意,对于以下所有示例,您还可以全局设置环境变量(即在您的 shell 中使用 export,而不是将 --repo_env 参数传递给您的命令),这样通过 build/build.py 调用 bazel 也可以正常工作。

使用来自互联网的自定义 Python 3.13 进行构建,使用此仓库中已有的默认 requirements_lock_3_13.txt(即自定义解释器,但默认依赖项)

bazel build <target>
  --repo_env=HERMETIC_PYTHON_VERSION=3.13
  --repo_env=HERMETIC_PYTHON_URL="https://github.com/indygreg/python-build-standalone/releases/download/20241016/cpython-3.13.0+20241016-x86_64-unknown-linux-gnu-install_only.tar.gz"
  --repo_env=HERMETIC_PYTHON_SHA256="2c8cb15c6a2caadaa98af51df6fe78a8155b8471cb3dd7b9836038e0d3657fb4"

使用来自本地文件系统的自定义 Python 3.13 和自定义锁定文件进行构建(假设锁定文件在运行命令之前已放入此仓库的 jax/build 文件夹中)

bazel test <target>
  --repo_env=HERMETIC_PYTHON_VERSION=3.13
  --repo_env=HERMETIC_PYTHON_URL="file:///path/to/cpython.tar.gz"
  --repo_env=HERMETIC_PYTHON_PREFIX="prefix/to/strip/in/cython/tar/gz/archive"
  --repo_env=HERMETIC_PYTHON_SHA256=<sha256_sum>
  --repo_env=HERMETIC_REQUIREMENTS_LOCK="/absolute/path/to/build:custom_requirements_lock.txt"

如果默认的 Python 解释器对您来说足够好,而您只需要一组自定义的依赖项

bazel test <target>
  --repo_env=HERMETIC_PYTHON_VERSION=3.13
  --repo_env=HERMETIC_REQUIREMENTS_LOCK="/absolute/path/to/build:custom_requirements_lock.txt"

请注意,您可以有多个与同一 Python 版本对应的 requirement_lock.txt 文件,以支持不同的场景。您可以通过指定 HERMETIC_PYTHON_VERSION 来控制选择哪个文件。例如,在 WORKSPACE 文件中

requirements = {
  "3.10": "//build:requirements_lock_3_10.txt",
  "3.11": "//build:requirements_lock_3_11.txt",
  "3.12": "//build:requirements_lock_3_12.txt",
  "3.13": "//build:requirements_lock_3_13.txt",
  "3.13-scenario1": "//build:scenario1_requirements_lock_3_13.txt",
  "3.13-scenario2": "//build:scenario2_requirements_lock_3_13.txt",
},

然后,您可以在不更改环境中任何内容的情况下构建和测试不同的组合

# To build with scenario1 dependendencies:
bazel test <target> --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario1

# To build with scenario2 dependendencies:
bazel test <target> --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario2

# To build with default dependendencies:
bazel test <target> --repo_env=HERMETIC_PYTHON_VERSION=3.13

# To build with scenario1 dependendencies and custom Python 3.13 interpreter:
bazel test <target>
  --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario1
  --repo_env=HERMETIC_PYTHON_URL="file:///path/to/cpython.tar.gz"
  --repo_env=HERMETIC_PYTHON_SHA256=<sha256_sum>

安装 jax#

安装 jaxlib 后,您可以通过运行以下命令来安装 jax

pip install -e .  # installs jax

要从 GitHub 升级到最新版本,只需从 JAX 仓库根目录运行 git pull,然后通过运行 build.py 或在必要时升级 jaxlib 来重新构建。您不必重新安装 jax,因为 pip install -e 会在 site-packages 中设置指向仓库的符号链接。

运行测试#

有两种支持的机制来运行 JAX 测试,可以使用 Bazel 或 pytest。

使用 Bazel#

首先,使用 --configure_only 标志配置 JAX 构建。对于 CPU 测试,传递 --wheel_list=jaxlib;对于 GPU 测试,传递 CUDA/ROCM。

python build/build.py build --wheels=jaxlib --configure_only
python build/build.py build --wheels=jax-cuda-plugin --configure_only
python build/build.py build --wheels=jax-rocm-plugin --configure_only

您可以将其他选项传递给 build.py 以配置构建;有关详细信息,请参阅 jaxlib 构建文档。

默认情况下,Bazel 构建使用从源代码构建的 jaxlib 运行 JAX 测试。要运行 JAX 测试,请运行

bazel test //tests:cpu_tests //tests:backend_independent_tests

如果您有必要的硬件,还可以使用 //tests:gpu_tests//tests:tpu_tests

要使用预安装的 jaxlib 而不是构建它,您首先需要使其在封闭的 Python 中可用。要在封闭的 Python 中安装特定版本的 jaxlib,请运行(以 jaxlib >= 0.4.26 为例)

echo -e "\njaxlib >= 0.4.26" >> build/requirements.in
python build/build.py requirements_update

或者,要从本地 wheel 安装 jaxlib(假设 Python 3.12)

echo -e "\n$(realpath jaxlib-0.4.26-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in
python build/build.py requirements_update --python_version=3.12

一旦您在封闭的环境中安装了 jaxlib,请运行

bazel test --//jax:build_jaxlib=false //tests:cpu_tests //tests:backend_independent_tests

可以使用环境变量(见下文)控制许多测试行为。可以使用 Bazel 的 --test_env=FLAG=value 标志将环境变量传递给 JAX 测试。

一些 JAX 测试用于多个加速器(即 GPU、TPU)。当 JAX 已经安装时,您可以像这样运行 GPU 测试

bazel test //tests:gpu_tests --local_test_jobs=4 --test_tag_filters=multiaccelerator --//jax:build_jaxlib=false --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform

您可以通过在多个加速器上并行运行单加速器测试来加速它们。这也会触发每个加速器的多个并发测试。对于 GPU,您可以这样做

NB_GPUS=2
JOBS_PER_ACC=4
J=$((NB_GPUS * JOBS_PER_ACC))
MULTI_GPU="--run_under $PWD/build/parallel_accelerator_execute.sh --test_env=JAX_ACCELERATOR_COUNT=${NB_GPUS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_ACC} --local_test_jobs=$J"
bazel test //tests:gpu_tests //tests:backend_independent_tests --test_env=XLA_PYTHON_CLIENT_PREALLOCATE=false --test_tag_filters=-multiaccelerator $MULTI_GPU

使用 pytest#

首先,通过运行 pip install -r build/test-requirements.txt 安装依赖项。

要使用 pytest 运行所有 JAX 测试,我们建议使用 pytest-xdist,它可以并行运行测试。它作为 pip install -r build/test-requirements.txt 命令的一部分安装。

从存储库根目录运行

pytest -n auto tests

控制测试行为#

JAX 以组合方式生成测试用例,您可以使用 JAX_NUM_GENERATED_CASES 环境变量控制为每个测试生成和检查的用例数量(默认为 10)。自动化测试目前默认使用 25。

例如,可以编写

# Bazel
bazel test //tests/... --test_env=JAX_NUM_GENERATED_CASES=25`

# pytest
JAX_NUM_GENERATED_CASES=25 pytest -n auto tests

自动化测试还会使用默认的 64 位浮点数和整数 (JAX_ENABLE_X64) 运行测试。

JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25 pytest -n auto tests

您可以使用 pytest 的内置选择机制运行更具体的测试集,或者,您可以直接运行特定的测试文件以查看有关正在运行的用例的更详细信息

JAX_NUM_GENERATED_CASES=5 python tests/lax_numpy_test.py

您可以通过传递环境变量 JAX_SKIP_SLOW_TESTS=1 来跳过一些已知速度较慢的测试。

要从测试文件中指定要运行的特定测试集,您可以通过 --test_targets 标志传递字符串或正则表达式。例如,您可以使用以下方法运行 jax.numpy.pad 的所有测试

python tests/lax_numpy_test.py --test_targets="testPad"

Colab 笔记本会作为文档构建的一部分进行错误测试。

假设测试#

一些测试使用 hypothesis。通常,hypothesis 将使用多个示例输入进行测试,并且在测试失败时,它将尝试找到仍然导致失败的更小示例:在测试失败中查找如下所示的行,并添加消息中提到的装饰器

You can reproduce this example by temporarily adding @reproduce_failure('6.97.4', b'AXicY2DAAAAAEwAB') as a decorator on your test case

对于交互式开发,您可以设置环境变量 JAX_HYPOTHESIS_PROFILE=interactive (或等效标志 --jax_hypothesis_profile=interactive)以将示例数设置为 1,并跳过示例最小化阶段。

文档测试#

JAX 使用 pytest 的文档测试模式来测试文档中的代码示例。您可以在 ci-build.yaml 中找到运行文档测试的最新命令。例如,您可以运行

JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md

此外,JAX 在 doctest-modules 模式下运行 pytest,以确保函数文档字符串中的代码示例可以正确运行。您可以在本地使用例如以下命令运行

JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest --doctest-modules jax/_src/numpy/lax_numpy.py

类型检查#

我们使用 mypy 来检查类型提示。要使用与 github CI 检查相同的配置运行 mypy,您可以使用 pre-commit 框架

pip install pre-commit
pre-commit run mypy --all-files

由于 mypy 在检查所有文件时可能会比较慢,因此仅检查您修改过的文件可能会更方便。为此,请先暂存更改(即 git add 更改的文件),然后在提交更改之前运行此操作

pre-commit run mypy

代码检查#

JAX 使用 ruff 代码检查器来确保代码质量。要使用与 github CI 检查相同的配置运行 ruff,您可以使用 pre-commit 框架

pip install pre-commit
pre-commit run ruff --all-files

更新文档#

要重建文档,请安装几个包

pip install -r docs/requirements.txt

然后运行

sphinx-build -b html docs docs/build/html -j auto

这可能需要很长时间,因为它会执行文档源中的许多笔记本;如果您希望在不执行笔记本的情况下构建文档,则可以运行

sphinx-build -b html -D nb_execution_mode=off docs docs/build/html -j auto

然后,您可以在 docs/build/html/index.html 中查看生成的文档。

-j auto 选项控制构建的并行性。您可以使用数字代替 auto 来控制要使用的 CPU 核心数。

更新笔记本#

我们使用 jupytextdocs/notebooks 中维护笔记本的两个同步副本:一个采用 ipynb 格式,另一个采用 md 格式。前者的优点是可以在 Colab 中直接打开和执行;后者的优点是它使在版本控制中跟踪差异变得更加容易。

编辑 ipynb#

对于对代码和输出进行大幅修改的较大更改,在 Jupyter 或 Colab 中编辑笔记本最容易。要在 Colab 界面中编辑笔记本,请打开 http://colab.research.google.com,然后从本地存储库中 Upload。根据需要进行更新,Run all cells,然后 Download ipynb。您可能需要使用如上所述的 sphinx-build 测试它是否正确执行。

编辑 md#

对于对笔记本的文本内容进行较小的更改,使用文本编辑器编辑 .md 版本最容易。

同步笔记本#

在编辑笔记本的 ipynb 或 md 版本之后,您可以使用 jupytext 通过在更新的笔记本上运行 jupytext --sync 来同步这两个版本;例如

pip install jupytext==1.16.4
jupytext --sync docs/notebooks/thinking_in_jax.ipynb

jupytext 版本应与 .pre-commit-config.yaml 中指定的版本匹配。

要检查 markdown 和 ipynb 文件是否已正确同步,您可以使用 pre-commit 框架来执行 github CI 使用的相同检查

pip install pre-commit
pre-commit run jupytext --all-files

创建新笔记本#

如果您要向文档添加新笔记本并希望使用此处讨论的 jupytext --sync 命令,则可以使用以下命令为 jupytext 设置笔记本

jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb

此操作通过向笔记本文件添加一个 "jupytext" 元数据字段来指定所需的格式,并且当调用 jupytext --sync 命令时,该字段会被识别。

Sphinx 构建中的笔记本#

一些笔记本(notebook)会自动构建,作为预提交检查的一部分和 Read the docs 构建的一部分。如果单元格引发错误,构建将会失败。如果这些错误是有意的,你可以捕获它们,或者使用 raises-exceptions 元数据标记单元格(示例 PR)。你必须在 .ipynb 文件中手动添加此元数据。当其他人重新保存笔记本时,它将被保留。

我们从构建中排除了一些笔记本,例如,因为它们包含长时间的计算。请参阅 conf.py 中的 exclude_patterns

readthedocs.io 上构建文档#

JAX 的自动生成文档位于 https://jax.ac.cn/

整个项目的文档构建由 readthedocs JAX 设置控制。当前设置会在代码被推送到 GitHub main 分支后立即触发文档构建。对于每个代码版本,构建过程由 .readthedocs.ymldocs/conf.py 配置文件驱动。

对于每个自动文档构建,你可以查看文档构建日志

如果你想在 Readthedocs 上测试文档生成,可以将代码推送到 test-docs 分支。该分支也会自动构建,你可以在这里查看生成的文档。如果文档构建失败,你可能需要清除 test-docs 的构建环境

对于本地测试,我能够在新的目录中通过重放我在 Readthedocs 日志中看到的命令来完成。

mkvirtualenv jax-docs  # A new virtualenv
mkdir jax-docs  # A new directory
cd jax-docs
git clone --no-single-branch --depth 50 https://github.com/jax-ml/jax
cd jax
git checkout --force origin/test-docs
git clean -d -f -f
workon jax-docs

python -m pip install --upgrade --no-cache-dir pip
python -m pip install --upgrade --no-cache-dir -I Pygments==2.3.1 setuptools==41.0.1 docutils==0.14 mock==1.0.1 pillow==5.4.1 alabaster>=0.7,<0.8,!=0.7.5 commonmark==0.8.1 recommonmark==0.5.0 'sphinx<2' 'sphinx-rtd-theme<0.5' 'readthedocs-sphinx-ext<1.1'
python -m pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
cd docs
python `which sphinx-build` -T -E -b html -d _build/doctrees-readthedocs -D language=en . _build/html