伪随机数#
如果所有因
rand()
函数不良而导致结果存疑的科学论文都从图书馆书架上消失,那么每个书架上都会出现一个大约拳头大小的空隙。 - Numerical Recipes
在本节中,我们将重点介绍 jax.random
和伪随机数生成(PRNG);也就是说,通过算法生成数字序列的过程,这些数字序列的属性近似于从适当分布中采样的随机数序列的属性。
PRNG 生成的序列不是真正的随机数,因为它们实际上由其初始值(通常称为 seed
,即种子)决定,并且随机采样的每个步骤都是某个 state
(即状态)的确定性函数,该状态从一个样本传递到下一个样本。
伪随机数生成是任何机器学习或科学计算框架的重要组成部分。通常,JAX 力求与 NumPy 兼容,但伪随机数生成是一个明显的例外。
为了更好地理解 JAX 和 NumPy 在随机数生成方面所采用的方法之间的差异,我们将在本节中讨论这两种方法。
NumPy 中的随机数#
NumPy 通过 numpy.random
模块原生支持伪随机数生成。在 NumPy 中,伪随机数生成基于全局 state
,可以使用 numpy.random.seed()
将其设置为确定性的初始条件。
import numpy as np
np.random.seed(0)
重复调用 NumPy 的有状态伪随机数生成器 (PRNG) 会改变全局状态并产生伪随机数流
print(np.random.random())
print(np.random.random())
print(np.random.random())
0.5488135039273248
0.7151893663724195
0.6027633760716439
在底层,NumPy 使用 梅森旋转算法 PRNG 来支持其伪随机函数。该 PRNG 的周期为 \(2^{19937}-1\),并且在任何时候都可以用 624 个 32 位无符号整数和一个指示此“熵”已使用多少的位置来描述。
您可以使用以下命令检查状态的内容。
def print_truncated_random_state():
"""To avoid spamming the outputs, print only part of the state."""
full_random_state = np.random.get_state()
print(str(full_random_state)[:460], '...')
print_truncated_random_state()
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
3904844661, 676747479, 2085143622, 1056793272, 3812477442,
2168787041, 275552121, 2696932952, 3432054210, 1657102335,
3518946594, 962584079, 1051271004, 3806145045, 1414436097,
2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
696824676, 2399811678, 3992505346, 569184356, 2626558620,
136797809, 4273176064, 296167901, 343 ...
state
通过每次调用随机函数进行更新
np.random.seed(0)
print_truncated_random_state()
('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044,
2481403966, 4042607538, 337614300, 3232553940, 1018809052,
3202401494, 1775180719, 3192392114, 594215549, 184016991,
829906058, 610491522, 3879932251, 3139825610, 297902587,
4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
2891506774, 1066338622, 135451537, 933040465, 2759011858,
2273819758, 3545703099, 2516396728, 127 ...
_ = np.random.uniform()
print_truncated_random_state()
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
3904844661, 676747479, 2085143622, 1056793272, 3812477442,
2168787041, 275552121, 2696932952, 3432054210, 1657102335,
3518946594, 962584079, 1051271004, 3806145045, 1414436097,
2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
696824676, 2399811678, 3992505346, 569184356, 2626558620,
136797809, 4273176064, 296167901, 343 ...
NumPy 允许您在单个函数调用中采样单个数字或整个数字向量。例如,您可以通过执行以下操作从均匀分布中采样 3 个标量向量:
np.random.seed(0)
print(np.random.uniform(size=3))
[0.5488135 0.71518937 0.60276338]
NumPy 提供顺序等价保证,这意味着连续单独采样 N 个数字或采样 N 个数字的向量会产生相同的伪随机序列
np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))
np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))
individually: [0.5488135 0.71518937 0.60276338]
all at once: [0.5488135 0.71518937 0.60276338]
JAX 中的随机数#
JAX 的随机数生成与 NumPy 的随机数生成在重要方面有所不同,因为 NumPy 的 PRNG 设计很难同时保证许多期望的属性。具体来说,在 JAX 中,我们希望 PRNG 生成是:
可重现的,
可并行化的,
可向量化的。
我们将在下面讨论原因。首先,我们将重点关注基于全局状态的 PRNG 设计的含义。考虑以下代码:
import numpy as np
np.random.seed(0)
def bar(): return np.random.uniform()
def baz(): return np.random.uniform()
def foo(): return bar() + 2 * baz()
print(foo())
1.9791922366721637
函数 foo
对从均匀分布中采样的两个标量求和。
只有假设 bar()
和 baz()
的执行顺序可预测,此代码的输出才能满足要求 #1。这在 NumPy 中不是问题,NumPy 始终按照 Python 解释器定义的顺序评估代码。然而,在 JAX 中,这更成问题:为了高效执行,我们希望 JIT 编译器可以自由地重新排序、删除和融合我们定义的函数中的各种操作。此外,当在多设备环境中执行时,每个进程同步全局状态的需求会妨碍执行效率。
显式随机状态#
为了避免这些问题,JAX 避免了隐式全局随机状态,而是通过随机 key
显式跟踪状态
from jax import random
key = random.key(42)
print(key)
Array((), dtype=key<fry>) overlaying:
[ 0 42]
注意
本节使用由 jax.random.key()
生成的新式类型化 PRNG 密钥,而不是由 jax.random.PRNGKey()
生成的旧式原始 PRNG 密钥。有关详细信息,请参阅 JEP 9263:类型化密钥和可插拔 RNG。
密钥是一个数组,其具有与正在使用的特定 PRNG 实现相对应的特殊数据类型;在默认实现中,每个密钥都由一对 uint32
值支持。
该密钥实际上是 NumPy 隐藏状态对象的替代品,但我们会将其显式传递给 jax.random()
函数。重要的是,随机函数会使用密钥,但不会修改它:将相同的密钥对象提供给随机函数将始终生成相同的样本。
print(random.normal(key))
print(random.normal(key))
-0.18471177
-0.18471177
重复使用同一个密钥,即使使用不同的 random
API,也可能导致输出相关,这通常是不希望的。
经验法则是:永远不要重复使用密钥(除非您想要相同的输出)。
JAX 使用现代的 基于 Threefry 计数器的可拆分 PRNG。也就是说,它的设计允许我们将 PRNG 状态分支到新的 PRNG 中,以用于并行随机生成。为了生成不同且独立的样本,您必须在将其传递给随机函数之前显式 split()
密钥
for i in range(3):
new_key, subkey = random.split(key)
del key # The old key is consumed by split() -- we must never use it again.
val = random.normal(subkey)
del subkey # The subkey is consumed by normal().
print(f"draw {i}: {val}")
key = new_key # new_key is safe to use in the next iteration.
draw 0: 1.369469404220581
draw 1: -0.19947023689746857
draw 2: -2.298278331756592
(此处不需要调用 del
,但我们这样做是为了强调一旦使用密钥,就不应重复使用。)
jax.random.split()
是一个确定性函数,它将一个 key
转换为多个独立的(在伪随机意义上)密钥。我们将其中一个输出保留为 new_key
,并且可以安全地将唯一的额外密钥(称为 subkey
)作为随机函数的输入,然后永远丢弃它。如果要从正态分布中获取另一个样本,则需要再次拆分 key
,依此类推:关键是永远不要重复使用同一个密钥。
我们调用 split(key)
输出的哪一部分为 key
,哪一部分为 subkey
并不重要。它们都是具有同等地位的独立密钥。密钥/子密钥命名约定是一种典型的使用模式,有助于跟踪密钥的使用方式:子密钥注定要被随机函数立即使用,而密钥则被保留以便稍后生成更多随机性。
通常,上面的示例可以简洁地写为:
key, subkey = random.split(key)
这会自动丢弃旧密钥。值得注意的是,split()
可以创建您需要的任意多个密钥,而不仅仅是 2 个
key, *forty_two_subkeys = random.split(key, num=43)
缺乏顺序等价性#
NumPy 和 JAX 的随机模块之间的另一个区别与上述顺序等价保证有关。
与 NumPy 一样,JAX 的随机模块也允许对数字向量进行采样。但是,JAX 不提供顺序等价保证,因为这样做会干扰 SIMD 硬件上的向量化(上面的要求 #3)。
在下面的示例中,使用三个子密钥单独从正态分布中采样 3 个值与使用单个密钥并指定 shape=(3,)
得到的结果不同
key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)
key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))
individually: [-0.04838832 0.10796154 -1.2226542 ]
all at once: [ 0.18693547 -1.2806505 -1.5593132 ]
缺乏顺序等价性使我们可以更高效地编写代码;例如,我们可以使用 jax.vmap()
以向量化的方式计算相同的结果,而不是通过顺序循环生成上面的 sequence
import jax
print("vectorized:", jax.vmap(random.normal)(subkeys))
vectorized: [-0.04838832 0.10796154 -1.2226542 ]
下一步#
有关 JAX 随机数的更多信息,请参阅 jax.random
模块的文档。如果您对 JAX 随机数生成器的设计细节感兴趣,请参阅 JAX PRNG 设计。