关键概念#

本节简要介绍 JAX 包的一些关键概念。

JAX 数组 (jax.Array)#

JAX 中的默认数组实现是 jax.Array。在许多方面,它类似于您可能熟悉的 NumPy 包中的 numpy.ndarray 类型,但它有一些重要的区别。

数组创建#

我们通常不直接调用 jax.Array 构造函数,而是通过 JAX API 函数创建数组。例如,jax.numpy 提供了熟悉的 NumPy 风格的数组构建功能,如 jax.numpy.zeros(), jax.numpy.linspace(), jax.numpy.arange() 等。

import jax
import jax.numpy as jnp

x = jnp.arange(5)
isinstance(x, jax.Array)
True

如果在你的代码中使用 Python 类型注解,jax.Array 是 JAX 数组对象的适当注解(更多讨论请参见 jax.typing)。

数组设备和分片#

JAX 数组对象有一个 devices 方法,可以让你检查数组的内容存储在何处。在最简单的情况下,这将是单个 CPU 设备

x.devices()
{CpuDevice(id=0)}

通常,一个数组可能会以一种可以通过 sharding 属性检查的方式跨多个设备进行分片

x.sharding
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)

这里数组位于单个设备上,但通常 JAX 数组可以跨多个设备甚至多个主机进行分片。要了解更多关于分片数组和并行计算的信息,请参阅 并行编程简介

转换#

除了用于操作数组的函数之外,JAX 还包括许多对 JAX 函数进行操作的转换。这些包括:

以及其他几个转换。转换接受一个函数作为参数,并返回一个新的转换后的函数。例如,以下是如何 JIT 编译一个简单的 SELU 函数的方法:

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
print(selu_jit(1.0))
1.05

为了方便起见,你经常会看到使用 Python 的装饰器语法来应用转换

@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

jit(), vmap(), grad() 等转换是有效使用 JAX 的关键,我们将在后面的章节中详细介绍它们。

追踪#

转换背后的魔力是 追踪器 的概念。追踪器是数组对象的抽象占位符,并被传递给 JAX 函数,以便提取该函数编码的操作序列。

你可以在转换后的 JAX 代码中打印任何数组值来看到这一点;例如:

@jax.jit
def f(x):
  print(x)
  return x + 1

x = jnp.arange(5)
result = f(x)
Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace>

打印的值不是数组 x,而是 Tracer 实例,它表示 x 的基本属性,例如它的 shapedtype。通过使用跟踪值执行函数,JAX 可以确定函数编码的操作序列,然后再实际执行这些操作:像 jit(), vmap(), 和 grad() 这样的转换可以将输入操作的这个序列映射到转换后的操作序列。

Jaxprs#

JAX 有自己的操作序列的中间表示,称为 jaxpr。jaxpr(JAX exPRession 的缩写)是函数式程序的简单表示,包含一系列原始操作。

例如,考虑上面定义的 selu 函数

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

我们可以使用 jax.make_jaxpr() 实用程序,给定一个特定的输入,将此函数转换为 jaxpr:

x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)
{ lambda ; a:f32[5]. let
    b:bool[5] = gt a 0.0
    c:f32[5] = exp a
    d:f32[5] = mul 1.6699999570846558 c
    e:f32[5] = sub d 1.6699999570846558
    f:f32[5] = pjit[
      name=_where
      jaxpr={ lambda ; g:bool[5] h:f32[5] i:f32[5]. let
          j:f32[5] = select_n g i h
        in (j,) }
    ] b a e
    k:f32[5] = mul 1.0499999523162842 f
  in (k,) }

将其与 Python 函数定义进行比较,我们看到它编码了该函数所表示的精确操作序列。我们将在 JAX 内部结构:jaxpr 语言 中更深入地介绍 jaxpr。

Pytrees#

JAX 函数和转换基本上是对数组进行操作,但在实践中,编写可以处理数组集合的代码非常方便:例如,神经网络可能会将其参数组织成一个带有有意义键的数组字典。为了避免逐个处理这种结构,JAX 依赖于 pytree 抽象,以统一的方式处理此类集合。

以下是一些可以被视为 pytree 的对象示例:

# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]

print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]
# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
       [1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]
# Named tuple of parameters
from typing import NamedTuple

class Params(NamedTuple):
  a: int
  b: float

params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]

JAX 有许多用于处理 PyTree 的通用实用程序;例如,函数 jax.tree.map() 可用于将函数映射到树中的每个叶子节点,而 jax.tree.reduce() 可用于在树的叶子节点上应用归约。

你可以在 使用 pytrees 教程中了解更多信息。

伪随机数#

一般来说,JAX 力求与 NumPy 兼容,但伪随机数生成是一个明显的例外。NumPy 支持一种基于全局 state 的伪随机数生成方法,可以使用 numpy.random.seed() 设置。全局随机状态与 JAX 的计算模型交互不良,并且难以在不同的线程、进程和设备之间强制实现可重复性。JAX 而是通过随机 key 显式跟踪状态

from jax import random

key = random.key(43)
print(key)
Array((), dtype=key<fry>) overlaying:
[ 0 43]

这个键实际上是 NumPy 隐藏状态对象的替代品,但我们将其显式传递给 jax.random() 函数。重要的是,随机函数会消耗键,但不会修改它:将相同的键对象传递给随机函数将始终生成相同的样本。

print(random.normal(key))
print(random.normal(key))
0.81039715
0.81039715

经验法则是:永远不要重用键(除非你想要相同的输出)。

为了生成不同且独立的样本,你必须在将键传递给随机函数之前显式 split() 该键

for i in range(3):
  new_key, subkey = random.split(key)
  del key  # The old key is consumed by split() -- we must never use it again.

  val = random.normal(subkey)
  del subkey  # The subkey is consumed by normal().

  print(f"draw {i}: {val}")
  key = new_key  # new_key is safe to use in the next iteration.
draw 0: 0.19468608498573303
draw 1: 0.5202823877334595
draw 2: -2.072833299636841

请注意,此代码是线程安全的,因为局部随机状态消除了涉及全局状态的可能竞争条件。jax.random.split() 是一个确定性函数,它将一个 key 转换为几个独立的(在伪随机性意义上)键。

有关 JAX 中伪随机数的更多信息,请参阅 伪随机数 教程。