安装#

使用 JAX 需要安装两个软件包:jax,它是纯 Python 且跨平台的;以及 jaxlib,它包含编译后的二进制文件,并且需要为不同的操作系统和加速器进行不同的构建。

总结: 对于大多数用户来说,典型的 JAX 安装可能如下所示

  • 仅 CPU (Linux/macOS/Windows)

    pip install -U jax
    
  • GPU (NVIDIA, CUDA 12)

    pip install -U "jax[cuda12]"
    
  • TPU (Google Cloud TPU VM)

    pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    

支持的平台#

下表显示了所有支持的平台和安装选项。请检查您的设置是否受支持;如果显示“是”“实验性”,则单击相应的链接以了解如何更详细地安装 JAX。

Linux, x86_64

Linux, aarch64

Mac, x86_64

Mac, aarch64

Windows, x86_64

Windows WSL2, x86_64

CPU

NVIDIA GPU

不适用

实验性

Google Cloud TPU

不适用

不适用

不适用

不适用

不适用

AMD GPU

实验性

实验性

不适用

Apple GPU

不适用

不适用

实验性

不适用

不适用

Intel GPU

实验性

不适用

不适用

不适用

CPU#

pip 安装:CPU#

目前,JAX 团队为以下操作系统和架构发布 jaxlib wheel

  • Linux, x86_64

  • Linux, aarch64

  • macOS, Intel

  • macOS, 基于 Apple ARM

  • Windows, x86_64 (实验性)

要安装仅 CPU 版本的 JAX,这对于在笔记本电脑上进行本地开发可能很有用,您可以运行

pip install --upgrade pip
pip install --upgrade jax

在 Windows 上,如果您的计算机上尚未安装 Microsoft Visual Studio 2019 Redistributable,您可能还需要安装它。

其他操作系统和架构需要从源代码构建。尝试在其他操作系统和架构上使用 pip 安装可能会导致 jaxlib 未与 jax 一起安装,尽管 jax 可能成功安装(但在运行时失败)。

NVIDIA GPU#

JAX 支持 SM 版本为 5.2 (Maxwell) 或更高版本的 NVIDIA GPU。请注意,由于 NVIDIA 已在其软件中停止支持 Kepler GPU,因此 JAX 不再支持 Kepler 系列 GPU。

您必须先安装 NVIDIA 驱动程序。建议您安装 NVIDIA 提供的最新驱动程序,但对于 Linux 上的 CUDA 12,驱动程序版本必须 >= 525.60.13。

如果您需要在较旧的驱动程序上使用较新的 CUDA 工具包,例如在无法轻松更新 NVIDIA 驱动程序的集群上,您或许可以使用 NVIDIA 为此目的提供的 CUDA 向前兼容性软件包

pip 安装:NVIDIA GPU(CUDA,通过 pip 安装,更容易)#

有两种方法可以安装支持 NVIDIA GPU 的 JAX

  • 使用通过 pip wheel 安装的 NVIDIA CUDA 和 cuDNN

  • 使用自安装的 CUDA/cuDNN

JAX 团队强烈建议使用 pip wheel 安装 CUDA 和 cuDNN,因为它更容易!

NVIDIA 仅为 x86_64 和 aarch64 发布了 CUDA pip 软件包;在其他平台上,您必须使用 CUDA 的本地安装。

pip install --upgrade pip

# NVIDIA CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12]"

如果 JAX 检测到错误版本的 NVIDIA CUDA 库,则需要检查以下几项

  • 确保未设置 LD_LIBRARY_PATH,因为 LD_LIBRARY_PATH 会覆盖 NVIDIA CUDA 库。

  • 确保安装的 NVIDIA CUDA 库是 JAX 请求的那些。重新运行上述安装命令应该可以工作。

pip 安装:NVIDIA GPU(CUDA,本地安装,更难)#

如果您希望使用预安装的 NVIDIA CUDA 副本,则必须先安装 NVIDIA CUDAcuDNN

JAX 仅为 Linux x86_64 和 Linux aarch64 提供预构建的 CUDA 兼容 wheel。其他操作系统和架构的组合是可能的,但需要从源代码构建(有关详细信息,请参阅 从源代码构建)。

您应该使用至少与您的 NVIDIA CUDA 工具包的对应驱动程序版本一样新的 NVIDIA 驱动程序版本。如果您需要在较旧的驱动程序上使用较新的 CUDA 工具包,例如在无法轻松更新 NVIDIA 驱动程序的集群上,您或许可以使用 NVIDIA 为此目的提供的 CUDA 向前兼容性软件包

JAX 目前提供一个 CUDA wheel 变体

使用以下版本构建

与以下版本兼容

CUDA 12.3

CUDA >=12.1

CUDNN 9.1

CUDNN >=9.1, <10.0

NCCL 2.19

NCCL >=2.18

JAX 会检查您的库的版本,如果它们不够新,则会报告错误。设置 JAX_SKIP_CUDA_CONSTRAINTS_CHECK 环境变量将禁用此检查,但使用较旧版本的 CUDA 可能会导致错误或不正确的结果。

NCCL 是一个可选依赖项,仅当您执行多 GPU 计算时才需要。

要安装,请运行

pip install --upgrade pip

# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.0 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]"

这些 pip 安装在 Windows 上不起作用,可能会静默失败;请参阅上面的 表。

您可以使用以下命令查找您的 CUDA 版本

nvcc --version

JAX 使用 LD_LIBRARY_PATH 来查找 CUDA 库,使用 PATH 来查找二进制文件(ptxasnvlink)。请确保这些路径指向正确的 CUDA 安装。

JAX 需要 libdevice10.bc,它通常来自 cuda-nvvm 软件包。确保它存在于您的 CUDA 安装中。

如果您在使用预构建的 wheel 时遇到任何错误或问题,请在 GitHub 问题跟踪器上告知 JAX 团队。

NVIDIA GPU Docker 容器#

NVIDIA 提供了 JAX 工具箱容器,这些容器是包含 jax 和一些模型/框架的每日构建版本的前沿容器。

Google Cloud TPU#

pip 安装:Google Cloud TPU#

JAX 为 Google Cloud TPU 提供预构建的 wheel。要安装 JAX 以及相应版本的 jaxliblibtpu,您可以在云 TPU VM 中运行以下命令

pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

对于 Colab (https://colab.research.google.com/) 的用户,请确保您使用的是 TPU v2,而不是较旧的、已弃用的 TPU 运行时。

Mac GPU#

pip 安装#

Apple 提供了一个实验性的 Metal 插件。有关详细信息,请参阅 Apple 的 Metal 上 JAX 文档

注意: Metal 插件有几个注意事项

  • Metal 插件是新的且是实验性的,并且存在许多 已知问题。请在 JAX 问题跟踪器上报告任何问题。

  • Metal 插件目前需要非常特定的 jaxjaxlib 版本。随着插件 API 的成熟,此限制将逐渐放宽。

AMD GPU (Linux)#

JAX 具有实验性的 ROCm 支持。有两种安装 JAX 的方法

Intel GPU#

Intel 为 Intel GPU 硬件提供了一个实验性的 OneAPI 插件:intel-extension-for-openxla。有关更多详细信息和安装说明,请参阅以下两种方法之一

  1. Pip 安装:Intel GPU 上的 JAX 加速

  2. 使用 Intel 的 XLA Docker 容器

请报告与以下内容相关的任何问题

Conda(社区支持)#

Conda 安装#

有一个社区支持的 jax Conda 构建版本。要使用 conda 安装它,只需运行

conda install jax -c conda-forge

如果您在具有 NVIDIA GPU 的计算机上运行此命令,则应安装 CUDA 支持的 jaxlib 软件包。

要确保您安装的 jax 版本确实启用了 CUDA,请运行

conda install "jaxlib=*=*cuda*" jax -c conda-forge

如果您想覆盖 JAX 使用的 CUDA 版本,或者在没有 GPU 的计算机上安装 CUDA 构建版本,请按照 conda-forge 网站的技巧和窍门部分中的说明进行操作。

有关更多详细信息,请访问 conda-forgejaxlibjax 存储库。

JAX 每日构建版本安装#

每日构建版本反映的是构建时主 JAX 仓库的状态,可能无法通过完整的测试套件。

与安装 JAX 发布的说明不同,这里我们在命令行中显式地指定了 JAX 的所有包,因此如果存在更新的版本,pip 将会升级它们。

  • 仅限 CPU

pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
  • Google Cloud TPU

pip install -U --pre jax jaxlib libtpu requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  • NVIDIA GPU (CUDA 12)

pip install -U --pre jax jaxlib "jax-cuda12-plugin[with_cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
  • NVIDIA GPU (CUDA 12) 传统版

对于单体 CUDA jaxlibs 的历史每日构建版本,请使用以下内容。您很可能不需要这个;不会再构建更多的单体 CUDA jaxlibs,并且现有的版本将在 2024 年 9 月到期。请使用上面的 “CUDA 12” 选项。

pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html

从源代码构建 JAX#

请参考 从源代码构建

安装旧版 jaxlib wheels#

由于 Python 包索引的存储限制,JAX 团队会定期从 https://pypi.ac.cn/project/jax 的发布版本中删除旧的 jaxlib wheels。这些 wheels 仍然可以通过此处的 URL 直接安装。例如

# Install jaxlib on CPU via the wheel archive
pip install "jax[cpu]==0.3.25" -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Install the jaxlib 0.3.25 CPU wheel directly
pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html

对于特定的旧版 GPU wheels,请务必使用 jax_cuda_releases.html URL;例如

pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html