JAX PRNG 设计#
我们希望一个 PRNG 设计
具有**表达力**,即使用方便且不会限制用户编写具有所需精确行为的数值程序的能力,
能够以与后端无关的方式实现程序执行的**可重复性**,
具有**对 @jit 编译边界和设备后端不变的语义**,
能够使用 SIMD 硬件**矢量化生成数组值**,
是**可并行化的**,因为它不会在随机函数调用之间添加顺序约束,否则这些调用将没有任何数据依赖关系,
扩展到**多副本、多核心和分布式计算**,
**符合 JAX 和 XLA 的语义**和设计理念(这些理念最终是由其他实际问题推动的)。
作为这些的推论,我们认为设计应该是函数式的。另一个推论是,至少在当前的硬件约束下,我们将使用软件进行 PRNG。
TLDR **JAX PRNG = Threefry 计数器 PRNG + 一个面向数组的函数式 拆分模型**
内容#
三种编程模型和玩具示例程序#
这是一个**有状态全局** PRNG 的玩具示例,类似于 Numpy 程序中常用的那种
def foo(): return bar() + baz()
def bar(): return rand(RNG, (3, 4))
def baz(): return rand(RNG, (3, 4))
def main():
global RNG
RNG = RandomState(0)
return foo()
为了实现可复现性,我们需要控制bar()和baz()的评估顺序,即使它们之间没有显式的数据依赖关系。这种源于可复现性(#2)的顺序要求违反了并行化(#5),并且不符合JAX或XLA的函数式语义(#6),在函数式语义中,子表达式可以以任何顺序进行评估。即使我们不需要可复现性,从而允许任何评估顺序,跨调用的并行化(#5)仍然会因为需要更新共享状态而变得困难。此外,由于相同的PRNG状态需要在Python和任何编译代码中访问和维护,因此该模型可能会导致工程挑战,以实现编译不变性(#3)和扩展到多个副本(#6)。最后,表达能力有限(#1),因为foo()无法调用bar()或baz()而不会影响其自身(隐式)的PRNG状态。
模型是否支持向量化(#4)取决于一些额外的细节。在NumPy中,PRNG向量化受限于一个**顺序等价保证**。
In [1]: rng = np.random.RandomState(0)
In [2]: rng.randn(2)
Out[2]: array([1.76405235, 0.40015721])
In [3]: rng = np.random.RandomState(0)
In [4]: np.stack([rng.randn() for _ in range(2)])
Out[4]: array([1.76405235, 0.40015721])
为了允许在生成数组的原始PRNG函数调用中进行向量化(#4)(例如,使用形状参数调用rand()),我们放弃了这个顺序等价保证。任何本节讨论的三种编程模型都可以支持这种向量化,尽管它促使了基于计数器的PRNG的实现,如下一节所述。
有状态的PRNG用户编程模型并不理想。这是一个函数式模型的示例,但缺少我们称之为“分裂”的关键成分。
def foo(rng_1):
y, rng_2 = baz(rng_1)
z, rng_3 = bar(rng_2)
return y + z, rng_3
def bar(x, rng):
val, new_rng = rand(rng, (3, 4))
return val, new_rng
def baz(x, rng):
val, new_rng = rand(rng, (3, 4))
return val, new_rng
def main():
foo(RandomState(0))
此模型将PRNG状态显式地贯穿所有生成随机值的函数(原始或非原始):也就是说,每个随机函数都必须接受和返回状态。现在,foo()中对baz()的调用和对bar()的调用之间存在显式的数据依赖关系,因此数据流(以及顺序)变得明确,并且符合JAX现有的语义(#7),这与之前的模型不同。这种显式贯穿还可以使语义对编译边界(#3)保持不变。
显式贯穿对程序员来说很不方便。但更糟糕的是,它实际上并没有提高表达能力(#1):foo()仍然无法在调用bar()或baz()的同时维护其自己的PRNG状态。在不知道其调用者或其调用的子例程的情况下,函数必须在任何地方防御性地传入和返回rng状态。此外,它也没有改善并行化(#5)或扩展到多个副本(#6)的前景,因为即使在函数式编程意义上明确了顺序,所有内容仍然是顺序的。
简而言之,通过显式贯穿状态使代码函数化不足以实现我们的表达能力(#1)和性能(#5、#6)目标。
之前两种模型中的关键问题是顺序过多。为了减少顺序依赖的数量,我们使用**函数式可分裂PRNG**。分裂是一种机制,可以将一个新的PRNG状态“分叉”成两个PRNG状态,同时保持通常期望的PRNG属性(这两个新流在计算上是可并行化的,并产生独立的随机值,即它们的行为类似于**多流**)。
def foo(rng_1):
rng_2, rng_3 = split(rng_1, 2)
return bar(rng_2) + baz(rng_3)
def bar(x, rng):
return rand(rng, (3, 4))
def baz(x, rng):
return rand(rng, (3, 4))
def main():
foo(RandomState(0))
需要注意的一些要点
对bar()和baz()的调用之间没有顺序依赖关系,它们可以以任何顺序进行评估而不会影响结果的值,这解决了剩余的性能目标(#5、#6),
函数不需要返回PRNG的更新版本,并且可以轻松地调用随机子例程而不会影响现有的PRNG状态,从而提高了其他函数式模型的表达能力(#1)。
示例中没有显示,但作为选择(2)的结果,推进PRNG状态的唯一方法是调用split()。也就是说,我们有两种方法来实现(1),它们的区别在于它们是否将用户程序负担在显式调用split()上,如上述示例所示,或者将用户程序负担在显式贯穿中。我们更倾向于前者,即具有显式分裂的版本,因为我们可以轻松地根据它实现显式贯穿版本。
设计#
我们可以使用**基于计数器的PRNG**设计,特别是如《Parallel random numbers: as easy as 1, 2, 3》中所述的Threefry哈希函数。我们使用计数器来实现高效的向量化:对于给定的密钥,我们可以通过对整数范围[k + 1,…,k + sample_size]上的哈希函数进行映射,以向量化的方式生成一系列值。我们使用密钥和哈希函数来实现**可分裂PRNG**:也就是说,分裂是一种从现有密钥生成两个新密钥的方法。
type Sample = Int256
type Key = Sample -- important identification for splitting
type Count = Int32
hash :: Key -> Count -> Int256 -- output type equal to Key and Sample
split :: Key -> (Key, Key)
split key = (hash key 0, hash key 1)
draw_samples :: Key -> Int -> [Sample]
draw_samples key n = map (hash key) [1..n]
令人惊讶的是,抽取样本与分裂非常相似!关键在于输出类型的差异(即使类型已识别):在一个案例中,该值用于形成感兴趣的随机样本(例如,将随机位转换为表示随机正态数的浮点数),而在另一个案例中,该值用作进一步哈希的密钥。
哈希函数参数(类型为Key和Count)中的不对称性在于后者是微不足道的,并且在计算上便宜地通过任意数量进行推进,因为我们只需要增加整数值,而前者仅通过哈希进行推进。这就是我们使用计数参数进行向量化的原因。
更真实的示例用户程序#
当步骤需要PRNG(可能是用于dropout或VAE训练)时,主机上的训练循环可能如下所示。
rng = lax.rng.new_rng()
for i in xrange(num_steps):
rng, rng_input = lax.rng.split(rng)
params = compiled_update(rng_input, params, next(batches))
请注意,我们给用户带来了rng显式分裂的负担,但rng根本不需要从代码中返回。
以下是如何使用此PRNG模型与stax神经网络构建器库来实现dropout。
def Dropout(rate, mode='train'):
def init_fun(input_shape):
return input_shape, ()
def apply_fun(rng, params, inputs):
if mode == 'train':
keep = lax.random.bernoulli(rng, rate, inputs.shape)
return np.where(keep, inputs / rate, 0)
else:
return inputs
return init_fun, apply_fun
此处的rng值只是用于哈希的密钥,而不是特殊对象。rng参数传递给每个apply_fun,因此需要在串行和并行组合器中使用分裂进行处理。
def serial(*layers):
init_funs, apply_funs = zip(*layers)
def init_fun(input_shape):
...
def apply_fun(rng, params, inputs):
rngs = split(rng, len(layers))
for rng, param, apply_fun in zip(rngs, params, apply_funs):
inputs = apply_fun(rng, param, inputs)
return inputs
return init_fun, apply_fun
def parallel(*layers):
init_funs, apply_funs = zip(*layers)
def init_fun(input_shape):
...
def apply_fun(rng, params, inputs):
rngs = split(rng, len(layers))
return [f(r, p, x) for f, r, p, x in zip(apply_funs, rngs, params, inputs)]
return init_fun, apply_fun
这里我们使用了一个简单的扩展版本split,它可以产生多个副本。
权衡和替代方案#
我们没有利用任何设备硬件PRNG。
我们目前对所有后端的硬件PRNG状态的控制力不足。
即使我们做到了,它也将依赖于后端,并且我们可能需要在随机调用之间引入顺序依赖关系以确保确定性排序,从而确保可复现性。
我们不知道任何软件PRNG会成为瓶颈的工作负载。
我们可以考虑提供一个额外的API,允许访问硬件PRNG,供希望放弃其他期望(如严格的可复现性)的用户使用。
我们放弃了顺序等价保证,在该保证中,在一个调用中创建随机数组会产生与逐个随机元素创建扁平化数组相同的值。
此属性可能与向量化(高优先级)不兼容。
我们不知道任何此属性很重要的用户或示例。
用户可以在此API之上编写一层来提供此保证。
我们不能完全遵循
numpy.random
API。