术语表

术语表#

数组#

JAX 中 numpy.ndarray 的类似物。参见 jax.Array

CPU#

中央处理器的简称,CPU 是大多数计算机中提供的标准计算架构。JAX 可以在 CPU 上运行计算,但在 GPUTPU 上通常可以实现更好的性能。

设备#

用于指代 JAX 用于执行计算的 CPUGPUTPU 的通用名称。

前向模式自动微分#

参见 JVP

函数式编程#

一种编程范式,其中程序通过应用和组合 纯函数 来定义。JAX 旨在与函数式程序一起使用。

GPU#

GPU 是 Graphical Processing Unit 的缩写,最初专门用于屏幕上图像渲染相关的操作,但现在用途更加广泛。JAX 能够针对 GPU 对数组进行快速操作(另请参见 CPUTPU)。

jaxpr#

jaxpr 是 JAX 表达式 的缩写,是 JAX 生成的计算的中间表示,并转发到 XLA 进行编译和执行。有关更多讨论和示例,请参见 理解 Jaxprs

JIT#

JIT 是 Just In Time 编译的缩写,在 JAX 中通常指的是将数组操作编译到 XLA,最常使用 jax.jit() 完成。

JVP#

JVP 是 Jacobian 向量积 的缩写,有时也称为 前向模式 自动微分。有关更多详细信息,请参见 雅可比向量积 (JVP,即前向模式自动微分)。在 JAX 中,JVP 是一种通过 jax.jvp() 实现的 变换。另请参见 VJP

基本运算#

基本运算是在 JAX 程序中使用的计算的基本单元。 jax.lax 中的大多数函数都表示单个基本运算。在 jaxpr 中表示计算时,jaxpr 中的每个操作都是一个基本运算。

纯函数#

纯函数是指其输出仅基于其输入且没有副作用的函数。JAX 的 变换 模型旨在与纯函数一起使用。另请参见 函数式编程

pytree#

pytree 是一种抽象,它允许 JAX 以统一的方式处理元组、列表、字典和其他更通用的数组值容器。有关更详细的讨论,请参阅 使用 pytree

反向模式自动微分#

请参见 VJP

SPMD#

SPMD 是 Single Program Multi Data 的缩写,它指的是一种并行计算技术,其中相同的计算(例如,神经网络的前向传播)在不同的输入数据(例如,批次中的不同输入)上并行地在不同的设备(例如,多个 TPU)上运行。 jax.pmap() 是一个实现 SPMD 并行性的 JAX 变换

静态#

JIT 编译中,未跟踪的值(请参见 Tracer)。有时也指编译时对静态值的计算。

TPU#

TPU 是 Tensor Processing Unit 的缩写,是专门为深度学习应用中使用的 N 维张量的快速操作而设计的芯片。JAX 能够针对 TPU 对数组进行快速操作(另请参见 CPUGPU)。

Tracer#

一个用作 JAX Array 替身的对象,以确定 Python 函数执行的操作序列。在内部,JAX 通过 jax.core.Tracer 类实现此功能。

变换#

高阶函数:即,一个将函数作为输入并输出转换后的函数的函数。JAX 中的示例包括 jax.jit()jax.vmap()jax.grad()

VJP#

VJP 是 向量雅可比积 的缩写,有时也称为 反向模式 自动微分。有关更多详细信息,请参见 向量雅可比积 (VJP,即反向模式自动微分)。在 JAX 中,VJP 是一种通过 jax.vjp() 实现的 变换。另请参见 JVP

XLA#

XLA 是 Accelerated Linear Algebra 的缩写,是线性代数运算的特定领域编译器,是 JIT 编译的 JAX 代码的主要后端。请参见 https://tensorflowcn.cn/xla/

弱类型#

一种 JAX 数据类型,其类型提升语义与 Python 标量相同;请参见 JAX 中的弱类型值