伪随机数#
如果所有因错误的
rand()
而导致结果存疑的科学论文都从书架上消失,那么每个书架上都会留下一个大约拳头大小的空隙。 - 数值方法
在本节中,我们将重点介绍jax.random
和伪随机数生成 (PRNG);也就是说,通过算法生成数字序列的过程,这些数字序列的属性近似于从适当分布中采样的随机数序列的属性。
PRNG 生成的序列不是真正的随机数,因为它们实际上是由其初始值决定的,通常称为种子
,并且随机采样的每一步都是从一个样本到下一个样本传递的某个状态
的确定性函数。
伪随机数生成是任何机器学习或科学计算框架的基本组成部分。一般来说,JAX 力求与 NumPy 兼容,但伪随机数生成是一个值得注意的例外。
为了更好地理解 JAX 和 NumPy 在随机数生成方面所采用的方法之间的差异,我们将在本节中讨论这两种方法。
NumPy 中的随机数#
NumPy 通过numpy.random
模块原生支持伪随机数生成。在 NumPy 中,伪随机数生成基于全局状态
,可以使用numpy.random.seed()
将其设置为确定性的初始条件。
import numpy as np
np.random.seed(0)
重复调用 NumPy 的有状态伪随机数生成器 (PRNG) 会改变全局状态,并给出伪随机数流
print(np.random.random())
print(np.random.random())
print(np.random.random())
0.5488135039273248
0.7151893663724195
0.6027633760716439
在底层,NumPy 使用梅森旋转算法 PRNG 来驱动其伪随机函数。PRNG 的周期为\(2^{19937}-1\),并且在任何时候都可以用 624 个 32 位无符号整数和一个指示此“熵”使用了多少的位置来描述。
您可以使用以下命令检查状态的内容。
def print_truncated_random_state():
"""To avoid spamming the outputs, print only part of the state."""
full_random_state = np.random.get_state()
print(str(full_random_state)[:460], '...')
print_truncated_random_state()
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
3904844661, 676747479, 2085143622, 1056793272, 3812477442,
2168787041, 275552121, 2696932952, 3432054210, 1657102335,
3518946594, 962584079, 1051271004, 3806145045, 1414436097,
2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
696824676, 2399811678, 3992505346, 569184356, 2626558620,
136797809, 4273176064, 296167901, 343 ...
状态
通过每次调用随机函数进行更新
np.random.seed(0)
print_truncated_random_state()
('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044,
2481403966, 4042607538, 337614300, 3232553940, 1018809052,
3202401494, 1775180719, 3192392114, 594215549, 184016991,
829906058, 610491522, 3879932251, 3139825610, 297902587,
4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
2891506774, 1066338622, 135451537, 933040465, 2759011858,
2273819758, 3545703099, 2516396728, 127 ...
_ = np.random.uniform()
print_truncated_random_state()
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
3904844661, 676747479, 2085143622, 1056793272, 3812477442,
2168787041, 275552121, 2696932952, 3432054210, 1657102335,
3518946594, 962584079, 1051271004, 3806145045, 1414436097,
2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
696824676, 2399811678, 3992505346, 569184356, 2626558620,
136797809, 4273176064, 296167901, 343 ...
NumPy 允许您在单个函数调用中采样单个数字或整个数字向量。例如,您可以通过执行以下操作从均匀分布中采样一个包含 3 个标量的向量
np.random.seed(0)
print(np.random.uniform(size=3))
[0.5488135 0.71518937 0.60276338]
NumPy 提供顺序等价保证,这意味着连续单独采样 N 个数字或采样一个包含 N 个数字的向量会产生相同的伪随机序列
np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))
np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))
individually: [0.5488135 0.71518937 0.60276338]
all at once: [0.5488135 0.71518937 0.60276338]
JAX 中的随机数#
JAX 的随机数生成与 NumPy 的随机数生成在重要方面有所不同,因为 NumPy 的 PRNG 设计使得难以同时保证许多理想的属性。具体来说,在 JAX 中,我们希望 PRNG 生成具有以下特性
可重现性,
可并行化,
可向量化。
我们将在下面讨论原因。首先,我们将重点关注基于全局状态的 PRNG 设计的含义。考虑以下代码
import numpy as np
np.random.seed(0)
def bar(): return np.random.uniform()
def baz(): return np.random.uniform()
def foo(): return bar() + 2 * baz()
print(foo())
1.9791922366721637
函数foo
对从均匀分布中采样的两个标量求和。
如果我们假设bar()
和baz()
具有可预测的执行顺序,则此代码的输出只能满足要求 #1。这在 NumPy 中不是问题,NumPy 总是按照 Python 解释器定义的顺序评估代码。然而,在 JAX 中,这更有问题:为了高效执行,我们希望 JIT 编译器可以自由地重新排序、省略和融合我们定义的函数中的各种操作。此外,在多设备环境中执行时,每个进程都需要同步全局状态,这会妨碍执行效率。
显式随机状态#
为了避免这些问题,JAX 避免了隐式全局随机状态,而是通过随机key
显式地跟踪状态
from jax import random
key = random.key(42)
print(key)
Array((), dtype=key<fry>) overlaying:
[ 0 42]
注意
本节使用由jax.random.key()
生成的新式类型化 PRNG 密钥,而不是由jax.random.PRNGKey()
生成的旧式原始 PRNG 密钥。有关详细信息,请参阅JEP 9263:类型化密钥和可插拔 RNG。
密钥是一个数组,其具有与正在使用的特定 PRNG 实现相对应的特殊 dtype;在默认实现中,每个密钥都由一对uint32
值支持。
密钥实际上是 NumPy 隐藏的状态对象的替代品,但是我们将其显式传递给jax.random()
函数。重要的是,随机函数会消耗密钥,但不会修改它:将相同的密钥对象馈送到随机函数总是会生成相同的样本。
print(random.normal(key))
print(random.normal(key))
-0.028304616
-0.028304616
重复使用相同的密钥,即使使用不同的random
API,也可能导致相关输出,这通常是不希望的。
经验法则是:永远不要重复使用密钥(除非您想要相同的输出)。
JAX 使用现代基于 Threefry 计数器的 PRNG,它是可拆分的。也就是说,它的设计允许我们将 PRNG 状态分叉到新的 PRNG 中,以用于并行随机生成。为了生成不同且独立的样本,您必须在将其传递给随机函数之前显式地split()
密钥
for i in range(3):
new_key, subkey = random.split(key)
del key # The old key is consumed by split() -- we must never use it again.
val = random.normal(subkey)
del subkey # The subkey is consumed by normal().
print(f"draw {i}: {val}")
key = new_key # new_key is safe to use in the next iteration.
draw 0: 0.6057640314102173
draw 1: -0.21089035272598267
draw 2: -0.3948981463909149
(这里不需要调用del
,但我们这样做是为了强调密钥一旦被消耗就不应重复使用。)
jax.random.split()
是一个确定性函数,它将一个key
转换为几个独立的(在伪随机意义上)密钥。我们将其中一个输出保留为new_key
,并且可以安全地使用唯一的额外密钥(称为subkey
)作为随机函数的输入,然后永远丢弃它。如果您想从正态分布中获取另一个样本,您将再次拆分key
,依此类推:关键是您永远不要重复使用同一个密钥两次。
我们称split(key)
的哪个输出为key
以及哪个输出为subkey
并不重要。它们都是具有相同地位的独立密钥。密钥/子密钥命名约定是一种典型的使用模式,有助于跟踪密钥的使用方式:子密钥注定要被随机函数立即消耗,而密钥则保留以稍后生成更多随机性。
通常,上面的示例将简明地写为
key, subkey = random.split(key)
它会自动丢弃旧密钥。值得注意的是,split()
可以创建您需要的任意多个密钥,而不仅仅是 2 个
key, *forty_two_subkeys = random.split(key, num=43)
缺乏顺序等价性#
NumPy 和 JAX 的随机模块之间的另一个区别与上面提到的顺序等价保证有关。
与 NumPy 中一样,JAX 的随机模块也允许采样数字向量。然而,JAX 不提供顺序等价保证,因为这样做会干扰 SIMD 硬件上的向量化(上面的要求 #3)。
在下面的示例中,使用三个子密钥单独采样正态分布中的 3 个值会得到不同的结果,而使用单个密钥并指定shape=(3,)
则会得到不同的结果
key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)
key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))
individually: [0.07592554 0.60576403 0.4323065 ]
all at once: [-0.02830462 0.46713185 0.29570296]
缺乏顺序等价性使我们可以自由地更有效地编写代码;例如,我们可以使用jax.vmap()
以向量化的方式计算相同的结果,而不是通过顺序循环生成上面的sequence
import jax
print("vectorized:", jax.vmap(random.normal)(subkeys))
vectorized: [0.07592554 0.60576403 0.4323065 ]
下一步#
有关 JAX 随机数的更多信息,请参阅jax.random
模块的文档。如果您对 JAX 随机数生成器设计的详细信息感兴趣,请参阅JAX PRNG 设计。