安装#

使用 JAX 需要安装两个包: jax,它是纯 Python 且跨平台的,以及 jaxlib,它包含编译后的二进制文件,并且需要为不同的操作系统和加速器构建不同的版本。

总结: 对于大多数用户来说,典型的 JAX 安装可能看起来像这样

  • 仅限 CPU(Linux/macOS/Windows)

    pip install -U jax
    
  • GPU(NVIDIA,CUDA 12)

    pip install -U "jax[cuda12]"
    
  • TPU(Google Cloud TPU VM)

    pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    

支持的平台#

下表显示了所有支持的平台和安装选项。检查你的设置是否受支持;如果它显示“是”“实验性”,请单击相应的链接以详细了解如何在 JAX 中安装 JAX。

Linux,x86_64

Linux,aarch64

macOS,Intel x86_64,AMD GPU

macOS,Apple Silicon,基于 ARM

Windows,x86_64

Windows WSL2,x86_64

CPU

NVIDIA GPU

n/a

实验性

Google Cloud TPU

n/a

n/a

n/a

n/a

n/a

AMD GPU

实验性

n/a

Apple GPU

n/a

实验性

实验性

n/a

n/a

CPU#

pip 安装: CPU#

目前,JAX 团队为以下操作系统和体系结构发布了 jaxlib 轮子

  • Linux,x86_64

  • Linux,aarch64

  • macOS,Intel

  • macOS,Apple 基于 ARM

  • Windows,x86_64(实验性

要安装仅限 CPU 的 JAX 版本,这可能对在笔记本电脑上进行本地开发很有用,你可以运行

pip install --upgrade pip
pip install --upgrade jax

在 Windows 上,你可能还需要安装 Microsoft Visual Studio 2019 Redistributable,如果它尚未安装在你的机器上。

其他操作系统和架构需要从源代码构建。尝试在其他操作系统和架构上进行 pip 安装可能会导致 jaxlib 未与 jax 一起安装,尽管 jax 可能已成功安装(但在运行时失败)。

NVIDIA GPU#

JAX 支持 SM 版本 5.2(Maxwell)或更新的 NVIDIA GPU。请注意,由于 NVIDIA 已在其软件中停止支持 Kepler GPU,因此 JAX 不再支持 Kepler 系列 GPU。

您必须首先安装 NVIDIA 驱动程序。建议您安装 NVIDIA 提供的最新驱动程序,但对于 Linux 上的 CUDA 12,驱动程序版本必须 >= 525.60.13。

如果您需要在较旧的驱动程序上使用较新的 CUDA 工具包(例如在无法轻松更新 NVIDIA 驱动程序的集群上),您可以使用 NVIDIA 为此目的提供的 CUDA 向前兼容性软件包

pip 安装:NVIDIA GPU(CUDA,通过 pip 安装,更轻松)#

有两种方法可以安装支持 NVIDIA GPU 的 JAX

  • 使用从 pip 轮安装的 NVIDIA CUDA 和 cuDNN

  • 使用自安装的 CUDA/cuDNN

JAX 团队强烈建议使用 pip 轮安装 CUDA 和 cuDNN,因为它非常容易!

NVIDIA 仅为 x86_64 和 aarch64 发布了 CUDA pip 软件包;在其他平台上,您必须使用本地安装的 CUDA。

pip install --upgrade pip

# NVIDIA CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12]"

如果 JAX 检测到错误版本的 NVIDIA CUDA 库,则需要检查以下几个事项

  • 确保未设置 LD_LIBRARY_PATH,因为 LD_LIBRARY_PATH 会覆盖 NVIDIA CUDA 库。

  • 确保安装的 NVIDIA CUDA 库是 JAX 所要求的库。重新运行上面的安装命令应该可以解决问题。

pip 安装:NVIDIA GPU(CUDA,本地安装,更难)#

如果您更喜欢使用预安装的 NVIDIA CUDA 副本,则必须首先安装 NVIDIA CUDAcuDNN

JAX 为 **仅限 Linux x86_64 和 Linux aarch64** 提供预构建的 CUDA 兼容轮。其他操作系统和架构的组合也是可能的,但需要从源代码构建(请参阅 从源代码构建 了解更多信息)。

您应该使用至少与您的 NVIDIA CUDA 工具包的对应驱动程序版本 一样新的 NVIDIA 驱动程序版本。如果您需要在较旧的驱动程序上使用较新的 CUDA 工具包(例如在无法轻松更新 NVIDIA 驱动程序的集群上),您可以使用 NVIDIA 为此目的提供的 CUDA 向前兼容性软件包

JAX 目前提供一个 CUDA 轮变体

使用

兼容

CUDA 12.3

CUDA >=12.1

CUDNN 9.1

CUDNN >=9.1,<10.0

NCCL 2.19

NCCL >=2.18

JAX 会检查您的库版本,如果版本不够新,则会报错。设置 JAX_SKIP_CUDA_CONSTRAINTS_CHECK 环境变量将禁用检查,但使用较旧版本的 CUDA 可能会导致错误或不正确的结果。

NCCL 是一个可选依赖项,仅在您进行多 GPU 计算时才需要。

要安装,请运行

pip install --upgrade pip

# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.0 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]"

这些 pip 安装无法与 Windows 配合使用,并且可能会静默失败;请参阅表格 以上

您可以使用以下命令找到您的 CUDA 版本

nvcc --version

JAX 使用 LD_LIBRARY_PATH 查找 CUDA 库,使用 PATH 查找二进制文件(ptxasnvlink)。请确保这些路径指向正确的 CUDA 安装。

JAX 需要 libdevice10.bc,它通常来自 cuda-nvvm 软件包。确保它存在于您的 CUDA 安装中。

如果您在使用预构建的轮时遇到任何错误或问题,请在 GitHub 问题跟踪器 上告知 JAX 团队。

NVIDIA GPU Docker 容器#

NVIDIA 提供 JAX 工具箱 容器,它们是包含 jax 和一些模型/框架的夜间版本的尖端容器。

Google Cloud TPU#

pip 安装:Google Cloud TPU#

JAX 为 Google Cloud TPU 提供预构建的轮。要安装 JAX 以及 jaxliblibtpu 的适当版本,您可以在云 TPU VM 中运行以下命令

pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

对于 Colab(https://colab.research.google.com/)的用户,请确保您使用的是 *TPU v2* 而不是较旧的已弃用的 TPU 运行时。

Apple Silicon GPU(基于 ARM)#

pip 安装:Apple 基于 ARM 的 Silicon GPU#

Apple 为 Apple 基于 ARM 的 GPU 硬件提供了一个实验性的 Metal 插件。有关详细信息,请参阅 Apple 的 JAX on Metal 文档

注意: Metal 插件有几个注意事项

  • Metal 插件是新的且实验性的,并且有一些 已知问题。请在 JAX 问题跟踪器中报告任何问题。

  • Metal 插件目前需要非常具体的 jaxjaxlib 版本。随着插件 API 的成熟,此限制将随着时间的推移而放宽。

AMD GPU#

JAX 具有实验性的 ROCm 支持。有两种方法可以安装 JAX

Conda(社区支持)#

Conda 安装#

有一个社区支持的 jax 的 Conda 构建。要使用 conda 安装它,只需运行

conda install jax -c conda-forge

要在配备 NVIDIA GPU 的机器上安装它,请运行

conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia

请注意,由 conda-forge 分发的 cudatoolkit 缺少 JAX 所需的 ptxas。因此,您必须从 nvidia 频道安装 cuda-nvcc 软件包,或者单独在您的机器上安装 CUDA,以便 ptxas 位于您的路径中。上面的通道顺序很重要(conda-forgenvidia 之前)。

如果您想覆盖 JAX 使用的 CUDA 版本,或者要在没有 GPU 的机器上安装 CUDA 构建,请按照 conda-forge 网站的 提示和技巧 部分的说明进行操作。

转到 conda-forgejaxlibjax 存储库以获取更多详细信息。

JAX 夜间安装#

夜间版本反映了 JAX 主存储库在构建时的状态,可能无法通过完整的测试套件。

与安装 JAX 版本的说明不同,这里我们在命令行中明确地命名了 JAX 的所有软件包,因此如果存在更新的版本,pip 将对其进行升级。

  • 仅限 CPU

pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
  • Google Cloud TPU

pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  • NVIDIA GPU(CUDA 12)

pip install -U --pre jax jaxlib jax-cuda12-plugin[with_cuda] jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
  • NVIDIA GPU(CUDA 12)旧版

对于历史版本的单片 CUDA jaxlib,请使用以下命令。您很可能不需要这样做;不会再构建单片 CUDA jaxlib,现有的 jaxlib 将在 2024 年 9 月前过期。使用上面的“CUDA 12”选项。

pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html

从源代码构建 JAX#

请参阅 从源代码构建

安装较旧的 jaxlib#

由于 Python 软件包索引上的存储限制,JAX 团队会定期从 http://pypi.org/project/jax 上的版本中删除较旧的 jaxlib 轮。这些仍然可以通过这里的 URL 直接安装。例如

# Install jaxlib on CPU via the wheel archive
pip install jax[cpu]==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Install the jaxlib 0.3.25 CPU wheel directly
pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html

对于特定版本的较旧 GPU 轮,请确保使用 jax_cuda_releases.html URL;例如

pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html