有状态计算#

JAX 转换(例如 jit(), vmap(), grad())要求它们包装的函数是纯函数:即,输出完全依赖于输入的函数,并且没有诸如更新全局状态之类的副作用。您可以在 JAX 尖锐之处:纯函数中找到对此的讨论。

在机器学习的背景下,这种约束可能会带来一些挑战,因为状态可能以多种形式存在。例如:

  • 模型参数,

  • 优化器状态,以及

  • 有状态的层,例如 BatchNorm

本节提供有关如何在 JAX 程序中正确处理状态的一些建议。

一个简单的例子:计数器#

让我们从一个简单的有状态程序开始:一个计数器。

import jax
import jax.numpy as jnp

class Counter:
  """A simple counter."""

  def __init__(self):
    self.n = 0

  def count(self) -> int:
    """Increments the counter and returns the new value."""
    self.n += 1
    return self.n

  def reset(self):
    """Resets the counter to zero."""
    self.n = 0


counter = Counter()

for _ in range(3):
  print(counter.count())
1
2
3

计数器的 n 属性在连续调用 count 之间维护计数器的状态。它作为调用 count 的副作用被修改。

假设我们想要快速计数,因此我们 JIT 编译 count 方法。(在这个例子中,由于很多原因,这实际上并不能提高速度,但请将其视为 JIT 编译模型参数更新的玩具模型,其中 jit() 带来了巨大的差异)。

counter.reset()
fast_count = jax.jit(counter.count)

for _ in range(3):
  print(fast_count())
1
1
1

糟糕!我们的计数器不起作用了。这是因为 count 中的这一行

self.n += 1

count 中涉及到一个副作用:它就地修改了输入计数器,因此这个函数不受 jit 的支持。这种副作用只在函数第一次被跟踪时执行一次,后续调用不会重复该副作用。那么,我们该如何修复它呢?

解决方案:显式状态#

我们计数器的问题之一是返回值不依赖于参数,这意味着一个常量被“烘焙”到编译后的输出中。但这不应该是一个常量——它应该依赖于状态。那么,为什么不把状态变成一个参数呢?

CounterState = int

class CounterV2:

  def count(self, n: CounterState) -> tuple[int, CounterState]:
    # You could just return n+1, but here we separate its role as 
    # the output and as the counter state for didactic purposes.
    return n+1, n+1

  def reset(self) -> CounterState:
    return 0

counter = CounterV2()
state = counter.reset()

for _ in range(3):
  value, state = counter.count(state)
  print(value)
1
2
3

在这个新版本的 Counter 中,我们将 n 移动为 count 的参数,并添加了另一个返回值,表示新的、更新后的状态。要使用这个计数器,我们现在需要显式地跟踪状态。但作为回报,我们现在可以安全地使用 jax.jit 这个计数器了

state = counter.reset()
fast_count = jax.jit(counter.count)

for _ in range(3):
  value, state = fast_count(state)
  print(value)
1
2
3

一个通用策略#

我们可以将相同的过程应用于任何有状态的方法,将其转换为无状态的方法。我们采用如下形式的类:

class StatefulClass

  state: State

  def stateful_method(*args, **kwargs) -> Output:

并将其转换为如下形式的类:

class StatelessClass

  def stateless_method(state: State, *args, **kwargs) -> (Output, State):

这是一种常见的 函数式编程 模式,本质上,这是所有 JAX 程序中处理状态的方式。

请注意,一旦我们以这种方式重写它,对类的需求就变得不太明显了。我们可以只保留 stateless_method,因为该类不再执行任何工作。这是因为,就像我们刚刚应用的策略一样,面向对象编程 (OOP) 是一种帮助程序员理解程序状态的方式。

在我们的例子中,CounterV2 类只不过是一个命名空间,它将所有使用 CounterState 的函数集中到一个位置。留给读者练习:你认为把它保留为一个类有意义吗?

顺便说一句,你已经在 JAX 伪随机性 API jax.random 中看到了这个策略的示例,该示例在 伪随机数 部分中显示。与使用隐式更新的有状态类管理随机状态的 Numpy 不同,JAX 要求程序员直接使用随机生成器状态——PRNG 密钥。

简单的示例:线性回归#

让我们将这个策略应用于一个简单的机器学习模型:通过梯度下降进行线性回归。

在这里,我们只处理一种状态:模型参数。但通常,你会看到各种状态被传入和传出 JAX 函数,例如优化器状态、批量归一化的层统计信息等等。

要仔细查看的函数是 update

from typing import NamedTuple

class Params(NamedTuple):
  weight: jnp.ndarray
  bias: jnp.ndarray


def init(rng) -> Params:
  """Returns the initial model params."""
  weights_key, bias_key = jax.random.split(rng)
  weight = jax.random.normal(weights_key, ())
  bias = jax.random.normal(bias_key, ())
  return Params(weight, bias)


def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  """Computes the least squares error of the model's predictions on x against y."""
  pred = params.weight * x + params.bias
  return jnp.mean((pred - y) ** 2)


LEARNING_RATE = 0.005

@jax.jit
def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:
  """Performs one SGD update step on params using the given data."""
  grad = jax.grad(loss)(params, x, y)

  # If we were using Adam or another stateful optimizer,
  # we would also do something like
  #
  #   updates, new_optimizer_state = optimizer(grad, optimizer_state)
  # 
  # and then use `updates` instead of `grad` to actually update the params.
  # (And we'd include `new_optimizer_state` in the output, naturally.)

  new_params = jax.tree_map(
      lambda param, g: param - g * LEARNING_RATE, params, grad)

  return new_params

请注意,我们手动将参数输入和输出更新函数。

import matplotlib.pyplot as plt

rng = jax.random.key(42)

# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
x_rng, noise_rng = jax.random.split(rng)
xs = jax.random.normal(x_rng, (128, 1))
noise = jax.random.normal(noise_rng, (128, 1)) * 0.5
ys = xs * true_w + true_b + noise

# Fit regression
params = init(rng)
for _ in range(1000):
  params = update(params, xs, ys)

plt.scatter(xs, ys)
plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')
plt.legend();
/tmp/ipykernel_2788/721844192.py:37: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  new_params = jax.tree_map(
_images/33c94c378fa1781345bf34542361a470f23e5450911aaae933f5b30e325c0ccb.png

更进一步#

上述策略是任何 JAX 程序在使用 jitvmapgrad 等转换时必须处理状态的方式。

如果你处理的是两个参数,手动处理参数似乎没问题,但如果它是一个有几十层的神经网络呢?你可能已经开始担心两件事了:

  1. 我们是否应该手动初始化它们,本质上是重复我们在前向传递定义中已经编写的内容?

  2. 我们是否应该手动传递所有这些东西?

细节可能很难处理,但是有一些库的例子可以为你处理这个问题。请参阅 JAX 神经网络库 以获取一些示例。