有状态计算#

JAX 转换,例如 jit()vmap()grad(),要求它们包装的函数是纯函数:也就是说,函数的输出完全依赖于输入,并且没有任何副作用,例如更新全局状态。你可以在 JAX 核心要点:纯函数 中找到对此的讨论。

在机器学习的背景下,这种约束可能会带来一些挑战,因为状态可能以多种形式存在。例如

  • 模型参数,

  • 优化器状态,以及

  • 有状态层,例如 BatchNorm

本节提供了一些关于如何在 JAX 程序中正确处理状态的建议。

一个简单的例子:计数器#

让我们从一个简单的有状态程序开始:一个计数器。

import jax
import jax.numpy as jnp

class Counter:
  """A simple counter."""

  def __init__(self):
    self.n = 0

  def count(self) -> int:
    """Increments the counter and returns the new value."""
    self.n += 1
    return self.n

  def reset(self):
    """Resets the counter to zero."""
    self.n = 0


counter = Counter()

for _ in range(3):
  print(counter.count())
1
2
3

计数器的 n 属性在连续调用 count 时保持计数器的状态。它是在调用 count 时作为副作用修改的。

假设我们想要快速计数,因此我们对 count 方法进行 JIT 编译。(在本例中,由于多种原因,这实际上并不会提高速度,但请将其视为 JIT 编译模型参数更新的玩具模型,其中 jit() 会带来巨大差异)。

counter.reset()
fast_count = jax.jit(counter.count)

for _ in range(3):
  print(fast_count())
1
1
1

糟糕!我们的计数器无法正常工作。这是因为 count 中的代码行

self.n += 1

涉及副作用:它就地修改了输入计数器,因此此函数不受 jit 支持。此类副作用仅在函数首次追踪时执行一次,后续调用不会重复该副作用。那么,我们该如何解决呢?

解决方案:显式状态#

我们计数器出现问题的部分原因是返回的值不依赖于参数,这意味着常量被“烘焙”到已编译的输出中。但它不应该是一个常量 - 它应该取决于状态。那么,为什么我们不将状态作为参数呢?

CounterState = int

class CounterV2:

  def count(self, n: CounterState) -> tuple[int, CounterState]:
    # You could just return n+1, but here we separate its role as 
    # the output and as the counter state for didactic purposes.
    return n+1, n+1

  def reset(self) -> CounterState:
    return 0

counter = CounterV2()
state = counter.reset()

for _ in range(3):
  value, state = counter.count(state)
  print(value)
1
2
3

在这个新版本的 Counter 中,我们将 n 移到 count 的参数中,并添加了另一个返回值来表示新的、更新后的状态。要使用此计数器,我们现在需要显式跟踪状态。但作为回报,我们现在可以安全地 jax.jit 此计数器

state = counter.reset()
fast_count = jax.jit(counter.count)

for _ in range(3):
  value, state = fast_count(state)
  print(value)
1
2
3

通用策略#

我们可以将相同的过程应用于任何有状态方法,将其转换为无状态方法。我们采用了一种形式为

class StatefulClass

  state: State

  def stateful_method(*args, **kwargs) -> Output:

的类,并将其转换为形式为

class StatelessClass

  def stateless_method(state: State, *args, **kwargs) -> (Output, State):

的类。这是一种常见的 函数式编程 模式,从本质上讲,这是 JAX 程序中处理状态的方式。

请注意,一旦我们以这种方式重写了它,对类的需求变得不那么明显了。我们可以直接保留 stateless_method,因为该类不再执行任何工作。这是因为,就像我们刚刚应用的策略一样,面向对象编程 (OOP) 是一种帮助程序员理解程序状态的方式。

在我们的例子中,CounterV2 类只不过是一个命名空间,它将所有使用 CounterState 的函数整合到一个位置。练习:您认为将其保留为一个类是否有意义?

顺便说一下,您已经在 JAX 伪随机性 API,jax.random 中看到了此策略的示例,如 伪随机数 部分所示。与使用隐式更新的有状态类来管理随机状态的 NumPy 不同,JAX 要求程序员直接使用随机生成器状态 - PRNG 密钥。

简单的示例:线性回归#

让我们将此策略应用于一个简单的机器学习模型:通过梯度下降进行线性回归。

在这里,我们只处理一种状态:模型参数。但通常情况下,您会看到许多种状态被传入和传出 JAX 函数,例如优化器状态、用于批归一化的层统计信息等等。

需要仔细查看的函数是 update

from typing import NamedTuple

class Params(NamedTuple):
  weight: jnp.ndarray
  bias: jnp.ndarray


def init(rng) -> Params:
  """Returns the initial model params."""
  weights_key, bias_key = jax.random.split(rng)
  weight = jax.random.normal(weights_key, ())
  bias = jax.random.normal(bias_key, ())
  return Params(weight, bias)


def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  """Computes the least squares error of the model's predictions on x against y."""
  pred = params.weight * x + params.bias
  return jnp.mean((pred - y) ** 2)


LEARNING_RATE = 0.005

@jax.jit
def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:
  """Performs one SGD update step on params using the given data."""
  grad = jax.grad(loss)(params, x, y)

  # If we were using Adam or another stateful optimizer,
  # we would also do something like
  #
  #   updates, new_optimizer_state = optimizer(grad, optimizer_state)
  # 
  # and then use `updates` instead of `grad` to actually update the params.
  # (And we'd include `new_optimizer_state` in the output, naturally.)

  new_params = jax.tree_map(
      lambda param, g: param - g * LEARNING_RATE, params, grad)

  return new_params

请注意,我们将参数手动传入和传出更新函数。

import matplotlib.pyplot as plt

rng = jax.random.key(42)

# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
x_rng, noise_rng = jax.random.split(rng)
xs = jax.random.normal(x_rng, (128, 1))
noise = jax.random.normal(noise_rng, (128, 1)) * 0.5
ys = xs * true_w + true_b + noise

# Fit regression
params = init(rng)
for _ in range(1000):
  params = update(params, xs, ys)

plt.scatter(xs, ys)
plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')
plt.legend();
/tmp/ipykernel_3123/721844192.py:37: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  new_params = jax.tree_map(
_images/5dea2f929fb59e89273132b2695583526ad0d63ce93cd139532e8cc5bc433783.png

更进一步#

上面描述的策略是任何 JAX 程序在使用 jitvmapgrad 等转换时必须处理状态的方式。

如果您处理的是两个参数,手动处理参数似乎没问题,但如果这是一个具有数十个层的网络呢?您可能已经开始担心两件事

  1. 我们是否应该手动初始化所有这些参数,本质上重复我们在前向传播定义中已经编写的内容?

  2. 我们是否应该手动将所有这些东西传递到各个函数之间?

细节可能难以处理,但有一些库可以为您处理这些细节。请参阅 JAX 神经网络库 了解一些示例。