jax.random 模块#

伪随机数生成的实用程序。

The jax.random 包提供了许多用于确定性生成伪随机数序列的例程。

基本用法#

>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
...   key, subkey = jax.random.split(key)
...   params = compiled_update(subkey, params, next(batches))  

PRNG 密钥#

与 NumPy 和 SciPy 用户习惯的有状态伪随机数生成器 (PRNG) 不同,JAX 随机函数都需要一个显式的 PRNG 状态作为第一个参数传递。随机状态由一个特殊的数组元素类型描述,我们称之为密钥,通常由 jax.random.key() 函数生成

>>> from jax import random
>>> key = random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]

然后,此密钥可以在 JAX 的任何随机数生成例程中使用

>>> random.uniform(key)
Array(0.41845703, dtype=float32)

请注意,使用密钥不会修改它,因此重复使用相同的密钥将导致相同的结果

>>> random.uniform(key)
Array(0.41845703, dtype=float32)

如果您需要新的随机数,可以使用 jax.random.split() 生成新的子密钥

>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
Array(0.10536897, dtype=float32)

注意

类型化的密钥数组,其元素类型如上所示的 key<fry>,是在 JAX v0.4.16 中引入的。在此之前,密钥通常以 uint32 数组表示,其最终维度表示密钥的位级表示。

密钥数组的两种形式仍然可以使用jax.random模块创建和使用。新风格的类型化密钥数组使用jax.random.key()创建。旧版uint32密钥数组使用jax.random.PRNGKey()创建。

要在这两者之间进行转换,请使用jax.random.key_data()jax.random.wrap_key_data()。当与 JAX 之外的系统交互(例如,将数组导出为可序列化格式)或将密钥传递给假设旧版格式的基于 JAX 的库时,可能需要旧版密钥格式。

否则,建议使用类型化密钥。相对于类型化密钥,旧版密钥的注意事项包括

  • 它们具有额外的尾随维度。

  • 它们具有数值数据类型(uint32),允许进行通常不应在密钥上执行的操作,例如整数算术。

  • 它们不携带有关 RNG 实现的信息。当旧版密钥传递给jax.random函数时,全局配置设置决定 RNG 实现(请参见下面的“高级 RNG 配置”)。

要详细了解此升级以及密钥类型的设计,请参阅JEP 9263

高级#

设计与背景#

TLDR:JAX PRNG = Threefry 计数器 PRNG + 面向函数的数组拆分模型

有关更多详细信息,请参阅docs/jep/263-prng.md

概括地说,除其他要求外,JAX PRNG 的目标是

  1. 确保可重复性,

  2. 良好地并行化,包括矢量化(生成数组值)和多副本、多核计算。特别是,它不应在随机函数调用之间使用顺序约束。

高级 RNG 配置#

JAX 提供了多种 PRNG 实现。可以使用jax.random.key的可选impl关键字参数选择特定的实现。当没有将impl选项传递给key构造函数时,实现由全局jax_default_prng_impl配置标志确定。可用实现的字符串名称为

  • "threefry2x32"(**默认**):基于 Threefry 哈希函数变体的基于计数器的 PRNG,如Salmon 等人 2011 年的这篇论文中所述。

  • "rbg""unsafe_rbg"(**实验性**):建立在XLA 的随机位生成器 (RBG) 算法之上的 PRNG。

    • "rbg"使用 XLA RBG 进行随机数生成,而对于密钥派生(如jax.random.splitjax.random.fold_in),它使用与"threefry2x32"相同的方法。

    • "unsafe_rbg"同时使用 XLA RBG 进行生成和密钥派生。

    这些实验方案生成的随机数尚未经过经验随机性测试(例如 BigCrush)。

    "unsafe_rbg"中的密钥派生也尚未经过经验测试。名称强调“不安全”,因为密钥派生质量和生成质量尚不清楚。

    此外,"rbg""unsafe_rbg"jax.vmap下表现异常。当对一批密钥上的随机函数进行 vmap 时,其输出值可能与其对相同密钥的真实映射不同。相反,在vmap下,整个批次的输出随机数都是仅从输入密钥批次中的第一个密钥生成的。例如,如果keys是 8 个密钥的向量,则jax.vmap(jax.random.normal)(keys)等于jax.random.normal(keys[0], shape=(8,))。这种特殊性反映了解决 XLA RBG 批处理支持有限的问题。

使用默认 RNG 的替代方案的原因包括

  1. 它可能难以编译为 TPU。

  2. 它在 TPU 上执行的速度相对较慢。

自动分区

为了使jax.jit能够有效地自动分区生成分片随机数数组(或密钥数组)的函数,所有 PRNG 实现都需要额外的标志

  • 对于"threefry2x32""rbg"密钥派生,请设置jax_threefry_partitionable=True

  • 对于"unsafe_rbg""rbg"随机数生成,请设置 XLA 标志--xla_tpu_spmd_rng_bit_generator_unsafe=1

可以使用XLA_FLAGS环境变量设置 XLA 标志,例如XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1

有关jax_threefry_partitionable的更多信息,请参阅https://jax.ac.cn/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers

摘要

属性

Threefry

Threefry*

rbg

unsafe_rbg

rbg**

unsafe_rbg**

TPU 上最快

可有效分片(使用 pjit)

跨分片相同

跨 CPU/GPU/TPU 相同

密钥上的精确jax.vmap

(*): 设置jax_threefry_partitionable=1

(**): 设置XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1

API 参考#

密钥创建与操作#

key(seed, *[, impl])

给定整数种子创建伪随机数生成器 (PRNG) 密钥。

key_data(keys)

恢复 PRNG 密钥数组底层的密钥数据位。

wrap_key_data(key_bits_array, *[, impl])

将密钥数据位的数组包装到 PRNG 密钥数组中。

fold_in(key, data)

将数据折叠到 PRNG 密钥中以形成新的 PRNG 密钥。

split(key[, num])

通过添加前导轴将 PRNG 密钥拆分为num个新密钥。

clone(key)

克隆一个密钥以供重复使用

PRNGKey(seed, *[, impl])

给定整数种子创建旧版 PRNG 密钥。

随机采样器#

ball(key, d[, p, shape, dtype])

从单位 Lp 球体中均匀采样。

bernoulli(key[, p, shape])

使用给定的形状和均值采样伯努利随机值。

beta(key, a, b[, shape, dtype])

使用给定的形状和浮点数据类型采样 Beta 随机值。

binomial(key, n, p[, shape, dtype])

使用给定的形状和浮点数据类型采样二项式随机值。

bits(key[, shape, dtype])

以无符号整数的形式采样均匀位。

categorical(key, logits[, axis, shape])

从分类分布中采样随机值。

cauchy(key[, shape, dtype])

使用给定的形状和浮点数据类型采样柯西随机值。

chisquare(key, df[, shape, dtype])

使用给定的形状和浮点数据类型采样卡方随机值。

choice(key, a[, shape, replace, p, axis])

从给定数组中生成随机样本。

dirichlet(key, alpha[, shape, dtype])

使用给定的形状和浮点数据类型采样狄利克雷随机值。

double_sided_maxwell(key, loc, scale[, ...])

从双侧麦克斯韦分布中采样。

exponential(key[, shape, dtype])

以给定的形状和浮点类型采样指数随机值。

f(key, dfnum, dfden[, shape, dtype])

以给定的形状和浮点类型采样F分布随机值。

gamma(key, a[, shape, dtype])

以给定的形状和浮点类型采样Gamma随机值。

generalized_normal(key, p[, shape, dtype])

从广义正态分布中采样。

geometric(key, p[, shape, dtype])

以给定的形状和浮点类型采样几何随机值。

gumbel(key[, shape, dtype])

以给定的形状和浮点类型采样Gumbel随机值。

laplace(key[, shape, dtype])

以给定的形状和浮点类型采样拉普拉斯随机值。

loggamma(key, a[, shape, dtype])

以给定的形状和浮点类型采样对数伽马随机值。

logistic(key[, shape, dtype])

以给定的形状和浮点类型采样逻辑斯蒂随机值。

lognormal(key[, sigma, shape, dtype])

以给定的形状和浮点类型采样对数正态随机值。

maxwell(key[, shape, dtype])

从单侧麦克斯韦分布中采样。

multivariate_normal(key, mean, cov[, shape, ...])

以给定的均值和协方差采样多元正态随机值。

normal(key[, shape, dtype])

以给定的形状和浮点类型采样标准正态随机值。

orthogonal(key, n[, shape, dtype])

从正交群O(n)中均匀采样。

pareto(key, b[, shape, dtype])

以给定的形状和浮点类型采样帕累托随机值。

permutation(key, x[, axis, independent])

返回一个随机排列的数组或范围。

poisson(key, lam[, shape, dtype])

以给定的形状和整数类型采样泊松随机值。

rademacher(key[, shape, dtype])

从Rademacher分布中采样。

randint(key, shape, minval, maxval[, dtype])

以给定的形状/类型在[minval, maxval)中采样均匀随机值。

rayleigh(key, scale[, shape, dtype])

以给定的形状和浮点类型采样瑞利随机值。

t(key, df[, shape, dtype])

以给定的形状和浮点类型采样学生t分布随机值。

triangular(key, left, mode, right[, shape, ...])

以给定的形状和浮点类型采样三角形随机值。

truncated_normal(key, lower, upper[, shape, ...])

以给定的形状和类型采样截断的标准正态随机值。

uniform(key[, shape, dtype, minval, maxval])

以给定的形状/类型在[minval, maxval)中采样均匀随机值。

wald(key, mean[, shape, dtype])

以给定的形状和浮点类型采样沃尔德随机值。

weibull_min(key, scale, concentration[, ...])

从威布尔分布中采样。