安装#

使用 JAX 需要安装两个包:jax,它是纯 Python 且跨平台的;以及 jaxlib,它包含编译后的二进制文件,并且需要为不同的操作系统和加速器进行不同的构建。

总结: 对于大多数用户来说,典型的 JAX 安装可能如下所示:

  • 仅 CPU(Linux/macOS/Windows)

    pip install -U jax
    
  • GPU (NVIDIA, CUDA 12)

    pip install -U "jax[cuda12]"
    
  • TPU (Google Cloud TPU VM)

    pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    

支持的平台#

下表显示了所有支持的平台和安装选项。请检查您的设置是否受支持;如果显示“是”“实验性”,请单击相应的链接以了解如何更详细地安装 JAX。

Linux, x86_64

Linux, aarch64

Mac, x86_64

Mac, aarch64

Windows, x86_64

Windows WSL2, x86_64

CPU

NVIDIA GPU

不适用

实验性

Google Cloud TPU

不适用

不适用

不适用

不适用

不适用

AMD GPU

实验性

实验性

不适用

Apple GPU

不适用

不适用

实验性

不适用

不适用

Intel GPU

实验性

不适用

不适用

不适用

CPU#

pip 安装:CPU#

目前,JAX 团队为以下操作系统和架构发布 jaxlib wheels:

  • Linux, x86_64

  • Linux, aarch64

  • macOS, Intel

  • macOS, Apple 基于 ARM 的

  • Windows, x86_64(实验性

要安装 JAX 的仅 CPU 版本,这可能对在笔记本电脑上进行本地开发很有用,您可以运行:

pip install --upgrade pip
pip install --upgrade jax

在 Windows 上,如果您的计算机上尚未安装,您可能还需要安装 Microsoft Visual Studio 2019 Redistributable

其他操作系统和架构需要从源代码构建。尝试在其他操作系统和架构上使用 pip 安装可能会导致 jaxlib 未与 jax 一起安装,尽管 jax 可能会成功安装(但在运行时会失败)。

NVIDIA GPU#

JAX 支持 SM 版本为 5.2(Maxwell)或更新版本的 NVIDIA GPU。请注意,由于 NVIDIA 已在其软件中放弃对 Kepler GPU 的支持,JAX 不再支持 Kepler 系列 GPU。

您必须首先安装 NVIDIA 驱动程序。建议您安装 NVIDIA 提供的最新驱动程序,但对于 Linux 上的 CUDA 12,驱动程序版本必须 >= 525.60.13。

如果您需要使用较新的 CUDA 工具包和较旧的驱动程序,例如在无法轻松更新 NVIDIA 驱动程序的集群上,您或许可以使用 NVIDIA 为此目的提供的 CUDA 向前兼容包

pip 安装:NVIDIA GPU(CUDA,通过 pip 安装,更容易)#

有两种方法可以使用 NVIDIA GPU 支持安装 JAX:

  • 使用通过 pip wheels 安装的 NVIDIA CUDA 和 cuDNN

  • 使用自安装的 CUDA/cuDNN

JAX 团队强烈建议使用 pip wheels 安装 CUDA 和 cuDNN,因为它更容易!

NVIDIA 仅为 x86_64 和 aarch64 发布了 CUDA pip 包;在其他平台上,您必须使用 CUDA 的本地安装。

pip install --upgrade pip

# NVIDIA CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12]"

如果 JAX 检测到错误版本的 NVIDIA CUDA 库,您需要检查以下几项:

  • 确保未设置 LD_LIBRARY_PATH,因为 LD_LIBRARY_PATH 可以覆盖 NVIDIA CUDA 库。

  • 确保安装的 NVIDIA CUDA 库是 JAX 要求的那些库。重新运行上面的安装命令应该可以解决问题。

pip 安装:NVIDIA GPU(CUDA,本地安装,更困难)#

如果您更喜欢使用预安装的 NVIDIA CUDA 副本,则必须先安装 NVIDIA CUDAcuDNN

JAX 仅为 Linux x86_64 和 Linux aarch64 提供预构建的与 CUDA 兼容的 wheels。其他操作系统和架构的组合是可能的,但需要从源代码构建(请参阅 从源代码构建以了解更多信息)。

您应该使用至少与您的 NVIDIA CUDA 工具包的相应驱动程序版本一样新的 NVIDIA 驱动程序版本。如果您需要使用较新的 CUDA 工具包和较旧的驱动程序,例如在无法轻松更新 NVIDIA 驱动程序的集群上,您或许可以使用 NVIDIA 为此目的提供的 CUDA 向前兼容包

JAX 目前发布一个 CUDA wheel 变体:

构建于

兼容于

CUDA 12.3

CUDA >=12.1

CUDNN 9.1

CUDNN >=9.1, <10.0

NCCL 2.19

NCCL >=2.18

JAX 会检查您的库的版本,如果它们不够新,则会报告错误。设置 JAX_SKIP_CUDA_CONSTRAINTS_CHECK 环境变量将禁用检查,但使用较旧版本的 CUDA 可能会导致错误或不正确的结果。

NCCL 是一个可选依赖项,仅当您执行多 GPU 计算时才需要。

要安装,请运行:

pip install --upgrade pip

# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.0 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]"

这些 pip 安装不适用于 Windows,可能会静默失败;请参阅上面的表。

您可以使用以下命令查找您的 CUDA 版本:

nvcc --version

JAX 使用 LD_LIBRARY_PATH 查找 CUDA 库,使用 PATH 查找二进制文件(ptxasnvlink)。请确保这些路径指向正确的 CUDA 安装。

JAX 需要 libdevice10.bc,它通常来自 cuda-nvvm 包。请确保它存在于您的 CUDA 安装中。

如果您在预构建的 wheels 中遇到任何错误或问题,请在 GitHub 问题跟踪器上告知 JAX 团队。

NVIDIA GPU Docker 容器#

NVIDIA 提供了 JAX 工具箱容器,这些容器是包含 jax 和一些模型/框架的每晚构建版本的前沿容器。

Google Cloud TPU#

pip 安装:Google Cloud TPU#

JAX 为 Google Cloud TPU 提供预构建的 wheels。要安装 JAX 以及适当版本的 jaxliblibtpu,您可以在您的云 TPU VM 中运行以下命令:

pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

对于 Colab (https://colab.research.google.com/) 的用户,请确保您使用的是 TPU v2,而不是较旧的、已弃用的 TPU 运行时。

Mac GPU#

pip 安装#

Apple 提供了一个实验性的 Metal 插件。有关详细信息,请参阅Apple 的 Metal 上 JAX 文档

注意: Metal 插件存在一些注意事项

  • Metal 插件是新的,实验性的,并且存在许多已知问题。请在 JAX 问题跟踪器上报告任何问题。

  • Metal 插件目前需要非常特定版本的 jaxjaxlib。随着插件 API 的成熟,此限制将随着时间的推移而放宽。

AMD GPU (Linux)#

JAX 具有实验性的 ROCm 支持。有两种方法可以安装 JAX:

Intel GPU#

Intel 为 Intel GPU 硬件提供了一个实验性的 OneAPI 插件:intel-extension-for-openxla。有关更多详细信息和安装说明,请参阅以下两种方法之一:

  1. Pip 安装:在 Intel GPU 上加速 JAX

  2. 使用 Intel 的 XLA Docker 容器

请报告与以下相关的所有问题:

Conda(社区支持)#

Conda 安装#

有一个社区支持的 jax Conda 构建版本。要使用 conda 安装它,只需运行:

conda install jax -c conda-forge

如果在具有 NVIDIA GPU 的计算机上运行此命令,则应安装启用了 CUDA 的 jaxlib 包。

要确保您安装的 jax 版本确实启用了 CUDA,请运行:

conda install "jaxlib=*=*cuda*" jax -c conda-forge

如果您想覆盖 JAX 使用的 CUDA 版本,或者在没有 GPU 的计算机上安装 CUDA 构建版本,请按照 conda-forge 网站的提示和技巧部分中的说明进行操作。

有关更多详细信息,请转到 conda-forge jaxlibjax 存储库。

JAX 每夜构建版本安装#

每夜构建版本反映的是构建时 JAX 主仓库的状态,可能无法通过完整的测试套件。

与安装 JAX 发布版本的说明不同,这里我们在命令行上显式地命名 JAX 的所有包,因此如果存在更新的版本,pip 将会升级它们。

  • 仅 CPU

pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
  • Google Cloud TPU

pip install -U --pre jax jaxlib libtpu requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  • NVIDIA GPU (CUDA 12)

pip install -U --pre jax jaxlib jax-cuda12-plugin[with_cuda] jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
  • NVIDIA GPU (CUDA 12) 旧版

以下用于单体 CUDA jaxlibs 的历史每夜构建版本。 您很可能不需要它;不会再构建更多的单体 CUDA jaxlibs,并且现有的版本将在 2024 年 9 月过期。请使用上面的“CUDA 12”选项。

pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html

从源代码构建 JAX#

请参阅 从源代码构建

安装旧版 jaxlib wheel 包#

由于 Python 包索引上的存储限制,JAX 团队会定期从 https://pypi.ac.cn/project/jax 上删除较旧的 jaxlib wheel 包。 它们仍然可以通过此处的 URL 直接安装。 例如

# Install jaxlib on CPU via the wheel archive
pip install jax[cpu]==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Install the jaxlib 0.3.25 CPU wheel directly
pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html

对于特定的旧版 GPU wheel 包,请务必使用 jax_cuda_releases.html URL;例如

pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html