有状态计算#
JAX 转换(例如 jit()
, vmap()
, grad()
)要求它们包装的函数是纯函数:即,输出完全依赖于输入的函数,并且没有诸如更新全局状态之类的副作用。您可以在 JAX 尖锐之处:纯函数中找到对此的讨论。
在机器学习的背景下,这种约束可能会带来一些挑战,因为状态可能以多种形式存在。例如:
模型参数,
优化器状态,以及
有状态的层,例如 BatchNorm。
本节提供有关如何在 JAX 程序中正确处理状态的一些建议。
一个简单的例子:计数器#
让我们从一个简单的有状态程序开始:一个计数器。
import jax
import jax.numpy as jnp
class Counter:
"""A simple counter."""
def __init__(self):
self.n = 0
def count(self) -> int:
"""Increments the counter and returns the new value."""
self.n += 1
return self.n
def reset(self):
"""Resets the counter to zero."""
self.n = 0
counter = Counter()
for _ in range(3):
print(counter.count())
1
2
3
计数器的 n
属性在连续调用 count
之间维护计数器的状态。它作为调用 count
的副作用被修改。
假设我们想要快速计数,因此我们 JIT 编译 count
方法。(在这个例子中,由于很多原因,这实际上并不能提高速度,但请将其视为 JIT 编译模型参数更新的玩具模型,其中 jit()
带来了巨大的差异)。
counter.reset()
fast_count = jax.jit(counter.count)
for _ in range(3):
print(fast_count())
1
1
1
糟糕!我们的计数器不起作用了。这是因为 count
中的这一行
self.n += 1
在 count
中涉及到一个副作用:它就地修改了输入计数器,因此这个函数不受 jit
的支持。这种副作用只在函数第一次被跟踪时执行一次,后续调用不会重复该副作用。那么,我们该如何修复它呢?
解决方案:显式状态#
我们计数器的问题之一是返回值不依赖于参数,这意味着一个常量被“烘焙”到编译后的输出中。但这不应该是一个常量——它应该依赖于状态。那么,为什么不把状态变成一个参数呢?
CounterState = int
class CounterV2:
def count(self, n: CounterState) -> tuple[int, CounterState]:
# You could just return n+1, but here we separate its role as
# the output and as the counter state for didactic purposes.
return n+1, n+1
def reset(self) -> CounterState:
return 0
counter = CounterV2()
state = counter.reset()
for _ in range(3):
value, state = counter.count(state)
print(value)
1
2
3
在这个新版本的 Counter
中,我们将 n
移动为 count
的参数,并添加了另一个返回值,表示新的、更新后的状态。要使用这个计数器,我们现在需要显式地跟踪状态。但作为回报,我们现在可以安全地使用 jax.jit
这个计数器了
state = counter.reset()
fast_count = jax.jit(counter.count)
for _ in range(3):
value, state = fast_count(state)
print(value)
1
2
3
一个通用策略#
我们可以将相同的过程应用于任何有状态的方法,将其转换为无状态的方法。我们采用如下形式的类:
class StatefulClass
state: State
def stateful_method(*args, **kwargs) -> Output:
并将其转换为如下形式的类:
class StatelessClass
def stateless_method(state: State, *args, **kwargs) -> (Output, State):
这是一种常见的 函数式编程 模式,本质上,这是所有 JAX 程序中处理状态的方式。
请注意,一旦我们以这种方式重写它,对类的需求就变得不太明显了。我们可以只保留 stateless_method
,因为该类不再执行任何工作。这是因为,就像我们刚刚应用的策略一样,面向对象编程 (OOP) 是一种帮助程序员理解程序状态的方式。
在我们的例子中,CounterV2
类只不过是一个命名空间,它将所有使用 CounterState
的函数集中到一个位置。留给读者练习:你认为把它保留为一个类有意义吗?
顺便说一句,你已经在 JAX 伪随机性 API jax.random
中看到了这个策略的示例,该示例在 伪随机数 部分中显示。与使用隐式更新的有状态类管理随机状态的 Numpy 不同,JAX 要求程序员直接使用随机生成器状态——PRNG 密钥。
简单的示例:线性回归#
让我们将这个策略应用于一个简单的机器学习模型:通过梯度下降进行线性回归。
在这里,我们只处理一种状态:模型参数。但通常,你会看到各种状态被传入和传出 JAX 函数,例如优化器状态、批量归一化的层统计信息等等。
要仔细查看的函数是 update
。
from typing import NamedTuple
class Params(NamedTuple):
weight: jnp.ndarray
bias: jnp.ndarray
def init(rng) -> Params:
"""Returns the initial model params."""
weights_key, bias_key = jax.random.split(rng)
weight = jax.random.normal(weights_key, ())
bias = jax.random.normal(bias_key, ())
return Params(weight, bias)
def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""Computes the least squares error of the model's predictions on x against y."""
pred = params.weight * x + params.bias
return jnp.mean((pred - y) ** 2)
LEARNING_RATE = 0.005
@jax.jit
def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:
"""Performs one SGD update step on params using the given data."""
grad = jax.grad(loss)(params, x, y)
# If we were using Adam or another stateful optimizer,
# we would also do something like
#
# updates, new_optimizer_state = optimizer(grad, optimizer_state)
#
# and then use `updates` instead of `grad` to actually update the params.
# (And we'd include `new_optimizer_state` in the output, naturally.)
new_params = jax.tree_map(
lambda param, g: param - g * LEARNING_RATE, params, grad)
return new_params
请注意,我们手动将参数输入和输出更新函数。
import matplotlib.pyplot as plt
rng = jax.random.key(42)
# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
x_rng, noise_rng = jax.random.split(rng)
xs = jax.random.normal(x_rng, (128, 1))
noise = jax.random.normal(noise_rng, (128, 1)) * 0.5
ys = xs * true_w + true_b + noise
# Fit regression
params = init(rng)
for _ in range(1000):
params = update(params, xs, ys)
plt.scatter(xs, ys)
plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')
plt.legend();
/tmp/ipykernel_2788/721844192.py:37: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
new_params = jax.tree_map(
更进一步#
上述策略是任何 JAX 程序在使用 jit
、vmap
、grad
等转换时必须处理状态的方式。
如果你处理的是两个参数,手动处理参数似乎没问题,但如果它是一个有几十层的神经网络呢?你可能已经开始担心两件事了:
我们是否应该手动初始化它们,本质上是重复我们在前向传递定义中已经编写的内容?
我们是否应该手动传递所有这些东西?
细节可能很难处理,但是有一些库的例子可以为你处理这个问题。请参阅 JAX 神经网络库 以获取一些示例。