伪随机数#
在本节中,我们将重点关注 jax.random
和伪随机数生成 (PRNG);也就是说,算法生成数字序列的过程,这些数字序列的属性近似于从适当分布中采样的随机数序列的属性。
PRNG 生成的序列并非真正随机,因为它们实际上是由其初始值确定的,该初始值通常称为 seed
,并且随机采样的每个步骤都是某个 state
的确定性函数,该状态从一个样本传递到下一个样本。
伪随机数生成是任何机器学习或科学计算框架的重要组成部分。通常,JAX 努力与 NumPy 保持兼容,但伪随机数生成是一个显著的例外。
为了更好地理解 JAX 和 NumPy 在随机数生成方面所采用的方法之间的区别,我们将在本节中讨论这两种方法。
NumPy 中的随机数#
NumPy 原生支持伪随机数生成,通过 numpy.random
模块实现。在 NumPy 中,伪随机数生成基于全局 state
,可以使用 numpy.random.seed()
将其设置为确定的初始条件。
import numpy as np
np.random.seed(0)
您可以使用以下命令检查状态内容。
def print_truncated_random_state():
"""To avoid spamming the outputs, print only part of the state."""
full_random_state = np.random.get_state()
print(str(full_random_state)[:460], '...')
print_truncated_random_state()
('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044,
2481403966, 4042607538, 337614300, 3232553940, 1018809052,
3202401494, 1775180719, 3192392114, 594215549, 184016991,
829906058, 610491522, 3879932251, 3139825610, 297902587,
4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
2891506774, 1066338622, 135451537, 933040465, 2759011858,
2273819758, 3545703099, 2516396728, 127 ...
每次调用随机函数时都会更新 state
np.random.seed(0)
print_truncated_random_state()
('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044,
2481403966, 4042607538, 337614300, 3232553940, 1018809052,
3202401494, 1775180719, 3192392114, 594215549, 184016991,
829906058, 610491522, 3879932251, 3139825610, 297902587,
4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
2891506774, 1066338622, 135451537, 933040465, 2759011858,
2273819758, 3545703099, 2516396728, 127 ...
_ = np.random.uniform()
print_truncated_random_state()
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
3904844661, 676747479, 2085143622, 1056793272, 3812477442,
2168787041, 275552121, 2696932952, 3432054210, 1657102335,
3518946594, 962584079, 1051271004, 3806145045, 1414436097,
2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
696824676, 2399811678, 3992505346, 569184356, 2626558620,
136797809, 4273176064, 296167901, 343 ...
NumPy 允许您在单个函数调用中采样单个数字或整个数字向量。例如,您可以通过以下方式从均匀分布中采样一个包含 3 个标量的向量:
np.random.seed(0)
print(np.random.uniform(size=3))
[0.5488135 0.71518937 0.60276338]
NumPy 提供了 *顺序等价保证*,这意味着依次单独采样 N 个数字或采样一个包含 N 个数字的向量会产生相同的伪随机序列。
np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))
np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))
individually: [0.5488135 0.71518937 0.60276338]
all at once: [0.5488135 0.71518937 0.60276338]
JAX 中的随机数#
JAX 的随机数生成与 NumPy 的有重要区别,因为 NumPy 的 PRNG 设计难以同时保证许多理想的特性。具体来说,在 JAX 中,我们希望 PRNG 生成是
可复现的,
可并行的,
可向量化的。
我们将在后面讨论原因。首先,我们将重点关注基于全局状态的 PRNG 设计的影响。考虑以下代码:
import numpy as np
np.random.seed(0)
def bar(): return np.random.uniform()
def baz(): return np.random.uniform()
def foo(): return bar() + 2 * baz()
print(foo())
1.9791922366721637
函数 foo
对从均匀分布中采样的两个标量求和。
只有在我们假设 bar()
和 baz()
的执行顺序可预测的情况下,此代码的输出才能满足要求 #1。在 NumPy 中这不是问题,NumPy 始终按照 Python 解释器定义的顺序执行代码。然而,在 JAX 中,这个问题更加突出:为了提高执行效率,我们希望 JIT 编译器能够自由地重新排序、省略和融合我们定义的函数中的各种操作。此外,在多设备环境中执行时,每个进程都需要同步全局状态,这会降低执行效率。
显式随机状态#
为了避免此问题,JAX 避免使用隐式全局随机状态,而是通过随机 key
显式跟踪状态。
from jax import random
key = random.key(42)
print(key)
Array((), dtype=key<fry>) overlaying:
[ 0 42]
注意
本节使用由 jax.random.key()
生成的新的类型化 PRNG 密钥,而不是由 jax.random.PRNGKey()
生成的旧的原始 PRNG 密钥。有关详细信息,请参阅 JEP 9263:类型化密钥和可插拔 RNG。
密钥是一个具有特殊 dtype 的数组,对应于正在使用的特定 PRNG 实现;在默认实现中,每个密钥都由一对 uint32
值支持。
密钥实际上是 NumPy 的隐藏状态对象的替代品,但我们将其显式传递给 jax.random()
函数。重要的是,随机函数会消耗密钥,但不会修改它:将相同的密钥对象馈送到随机函数始终会导致生成相同的样本。
print(random.normal(key))
print(random.normal(key))
-0.18471177
-0.18471177
即使使用不同的 random
API,重复使用相同的密钥也可能导致输出相关,这通常是不希望的。
经验法则是:永远不要重复使用密钥(除非您希望输出相同)。
为了生成不同的独立样本,您必须在将其传递给随机函数之前显式 split()
密钥
for i in range(3):
new_key, subkey = random.split(key)
del key # The old key is consumed by split() -- we must never use it again.
val = random.normal(subkey)
del subkey # The subkey is consumed by normal().
print(f"draw {i}: {val}")
key = new_key # new_key is safe to use in the next iteration.
draw 0: 1.369469404220581
draw 1: -0.19947023689746857
draw 2: -2.298278331756592
(此处不需要调用 del
,但我们这样做是为了强调一旦密钥被消耗就不应该重复使用。)
jax.random.split()
是一个确定性函数,它将一个 key
转换为多个独立的(在伪随机性意义上)密钥。我们将其中一个输出保留为 new_key
,并且可以安全地使用唯一的额外密钥(称为 subkey
)作为随机函数的输入,然后永远丢弃它。如果您想从正态分布中获取另一个样本,则需要再次拆分 key
,依此类推:关键是您永远不要两次使用相同的密钥。
我们称 split(key)
的哪个输出部分为 key
,哪个输出部分为 subkey
都无关紧要。它们都是具有相同状态的独立密钥。密钥/子密钥命名约定是一种典型的用法模式,有助于跟踪密钥的消耗方式:子密钥注定要被随机函数立即消耗,而密钥则保留以供以后生成更多随机性。
通常,上面的示例会简写为
key, subkey = random.split(key)
这会自动丢弃旧密钥。值得注意的是,split()
可以创建任意数量的密钥,而不仅仅是 2 个。
key, *forty_two_subkeys = random.split(key, num=43)
缺乏顺序等价性#
NumPy 和 JAX 的随机模块之间的另一个区别与上面提到的顺序等价保证有关。
与 NumPy 一样,JAX 的随机模块也允许采样数字向量。但是,JAX 不提供顺序等价保证,因为这样做会干扰 SIMD 硬件上的向量化(上面要求 #3)。
在下面的示例中,使用三个子密钥分别从正态分布中采样 3 个值的结果与使用单个密钥并指定 shape=(3,)
的结果不同。
key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)
key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))
individually: [-0.04838832 0.10796154 -1.2226542 ]
all at once: [ 0.18693547 -1.2806505 -1.5593132 ]
缺乏顺序等价性使我们能够更有效地编写代码;例如,我们可以使用 jax.vmap()
以向量化的方式计算相同的结果,而不是像上面那样通过顺序循环生成 sequence
。
import jax
print("vectorized:", jax.vmap(random.normal)(subkeys))
vectorized: [-0.04838832 0.10796154 -1.2226542 ]
后续步骤#
有关 JAX 随机数的更多信息,请参阅 jax.random
模块的文档。如果您有兴趣了解 JAX 随机数生成器设计的详细信息,请参阅 JAX PRNG 设计。