JEP 9263:类型化键和可插拔 RNG#
Jake VanderPlas, Roy Frostig
2023 年 8 月
概述#
展望未来,JAX 中的 RNG 键将更加类型安全和可定制。它将不再用长度为 2 的 uint32
数组表示单个 PRNG 键,而是用具有特殊 RNG 数据类型的标量数组表示,该数据类型满足 jnp.issubdtype(key.dtype, jax.dtypes.prng_key)
。
目前,旧式的 RNG 键仍然可以使用 jax.random.PRNGKey()
创建
>>> key = jax.random.PRNGKey(0)
>>> key
Array([0, 0], dtype=uint32)
>>> key.shape
(2,)
>>> key.dtype
dtype('uint32')
从现在开始,可以使用 jax.random.key()
创建新式 RNG 键
>>> key = jax.random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> key.shape
()
>>> key.dtype
key<fry>
这种(标量形状的)数组的行为与任何其他 JAX 数组相同,只是它的元素类型是一个键(以及相关的元数据)。我们也可以创建非标量的键数组,例如通过将 jax.vmap()
应用于 jax.random.key()
。
>>> key_arr = jax.vmap(jax.random.key)(jnp.arange(4))
>>> key_arr
Array((4,), dtype=key<fry>) overlaying:
[[0 0]
[0 1]
[0 2]
[0 3]]
>>> key_arr.shape
(4,)
除了切换到新的构造函数外,大多数与 PRNG 相关的代码应该继续按预期工作。您可以像以前一样继续在 jax.random
API 中使用键;例如
# split
new_key, subkey = jax.random.split(key)
# random number generation
data = jax.random.uniform(key, shape=(5,))
但是,并非所有数值运算都适用于键数组。现在它们会故意引发错误。
>>> key = key + 1
Traceback (most recent call last):
TypeError: add does not accept dtypes key<fry>, int32.
如果由于某种原因您需要恢复底层缓冲区(旧式键),您可以使用 jax.random.key_data()
来实现。
>>> jax.random.key_data(key)
Array([0, 0], dtype=uint32)
对于旧式键,key_data()
是一个恒等运算。
这对用户意味着什么?#
对于 JAX 用户来说,此更改现在不需要进行任何代码更改,但我们希望您会发现升级值得,并切换到使用类型化的键。要尝试此功能,请将 jax.random.PRNGKey() 的使用替换为 jax.random.key()。这可能会在您的代码中引入一些属于以下几个类别的破坏。
如果您的代码对键执行不安全/不支持的操作(例如索引、算术、转置等;请参阅下面的“类型安全”部分),此更改将捕获它。您可以更新代码以避免此类不支持的操作,或者使用
jax.random.key_data()
和jax.random.wrap_key_data()
以不安全的方式操作原始键缓冲区。如果您的代码包含关于
key.shape
的显式逻辑,您可能需要更新此逻辑,以考虑尾随键缓冲区维度不再是形状的显式部分这一事实。如果您的代码包含关于
key.dtype
的显式逻辑,您将需要升级它以使用新的公共 API 来推理 RNG 数据类型,例如dtypes.issubdtype(dtype, dtypes.prng_key)
。如果您调用一个尚未处理类型化 PRNG 键的基于 JAX 的库,您现在可以使用
raw_key = jax.random.key_data(key)
来恢复原始缓冲区,但请保留一个 TODO,以便在下游库支持类型化 RNG 键后删除此代码。
在未来的某个时刻,我们计划弃用 jax.random.PRNGKey()
,并要求使用 jax.random.key()
。
检测新式类型化键#
要检查一个对象是否是新式类型化的 PRNG 键,您可以使用 jax.dtypes.issubdtype
或 jax.numpy.issubdtype
。
>>> typed_key = jax.random.key(0)
>>> jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key)
True
>>> raw_key = jax.random.PRNGKey(0)
>>> jax.dtypes.issubdtype(raw_key.dtype, jax.dtypes.prng_key)
False
PRNG 键的类型注释#
新旧式 PRNG 键的推荐类型注释都是 jax.Array
。PRNG 键是根据其 dtype
与其他数组区分开的,并且目前无法在类型注释中指定 JAX 数组的 dtype。以前可以使用 jax.random.KeyArray
或 jax.random.PRNGKeyArray
作为类型注释,但在类型检查下,它们始终被别名为 Any
,因此 jax.Array
具有更高的特异性。
注意:jax.random.KeyArray
和 jax.random.PRNGKeyArray
已在 JAX 版本 0.4.16 中被弃用,并在 JAX 版本 0.4.24 中删除。.
动机#
此更改的两个主要动机是可定制性和安全性。
自定义 PRNG 实现#
JAX 目前使用单个全局配置的 PRNG 算法进行操作。PRNG 键是一个无符号 32 位整数的向量,jax.random API 使用这些整数来生成伪随机流。任何更高阶的 uint32 数组都被解释为此类键缓冲区的数组,其中尾随维度表示键。
当我们引入替代的 PRNG 实现时,此设计的缺点变得更加明显,必须通过设置全局或局部配置标志来选择这些实现。不同的 PRNG 实现具有不同大小的键缓冲区和不同的随机位生成算法。使用全局标志确定此行为容易出错,尤其是在整个进程中使用多个键实现时。
我们的新方法是将实现作为 PRNG 键类型的一部分来携带,即作为键数组的元素类型。使用新的键 API,这是一个在默认的 threefry2x32 实现(在纯 Python 中实现并使用 JAX 编译)和非默认的 rbg 实现(对应于单个 XLA 随机位生成操作)下生成伪随机值的示例:
>>> key = jax.random.key(0, impl='threefry2x32') # this is the default impl
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.947667 , 0.9785799 , 0.33229148], dtype=float32)
>>> key = jax.random.key(0, impl='rbg')
>>> key
Array((), dtype=key<rbg>) overlaying:
[0 0 0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.39904642, 0.8805201 , 0.73571277], dtype=float32)
安全使用 PRNG 键#
原则上,PRNG 键实际上只支持几个操作,即键派生(例如拆分)和随机数生成。PRNG 的设计目的是生成独立的伪随机数,前提是正确拆分键并且每个键只使用一次。
以其他方式操作或使用键数据的代码通常表示意外的错误,将键数组表示为原始 uint32 缓冲区使得沿着这些方向的滥用很容易发生。以下是我们在野外遇到的一些滥用示例:
键缓冲区索引#
访问底层整数缓冲区使得尝试以非标准方式派生键变得容易,有时会导致意想不到的糟糕后果。
# Incorrect
key = random.PRNGKey(999)
new_key = random.PRNGKey(key[1]) # identical to the original key!
# Correct
key = random.PRNGKey(999)
key, new_key = random.split(key)
如果此键是使用 random.key(999)
创建的新式类型化键,则索引到键缓冲区会出错。
键算术#
键算术是一种类似危险的从其他键派生键的方式。以避免 jax.random.split()
或 jax.random.fold_in()
的方式直接操作键数据派生键,会生成一批键,这些键(取决于 PRNG 实现)可能会在该批次内生成相关的随机数。
# Incorrect
key = random.PRNGKey(0)
batched_keys = key + jnp.arange(10, dtype=key.dtype)[:, None]
# Correct
key = random.PRNGKey(0)
batched_keys = random.split(key, 10)
使用 random.key(0)
创建的新式类型化键通过不允许对键进行算术运算来解决此问题。
意外转置键缓冲区#
对于“原始”旧式键数组,很容易意外地交换批次(前导)维度和键缓冲区(尾随)维度。同样,这可能导致生成相关的伪随机性的键。我们随着时间的推移看到的一种模式可以归结为以下内容:
# Incorrect
keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=1)(keys)
# Correct
keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=0)(keys)
这里的错误很微妙。通过映射 in_axes=1
,此代码通过组合批次中每个键缓冲区中的单个元素来创建新键。生成的键彼此不同,但实际上是以非标准方式“派生”的。同样,PRNG 的设计或测试目的不是从此类键批次生成独立的随机流。
使用 random.key(0)
创建的新式类型化键通过隐藏单个键的缓冲区表示来解决此问题,而是将键视为键数组的不透明元素。键数组没有要索引、转置或映射的尾随“缓冲区”维度。
键重用#
与基于状态的 PRNG API(如 numpy.random
)不同,JAX 的函数式 PRNG 在使用后不会隐式更新键。
# Incorrect
key = random.PRNGKey(0)
x = random.uniform(key, (100,))
y = random.uniform(key, (100,)) # Identical values!
# Correct
key = random.PRNGKey(0)
key1, key2 = random.split(random.key(0))
x = random.uniform(key1, (100,))
y = random.uniform(key2, (100,))
我们正在积极开发工具来检测和防止意外的键重用。这仍在进行中,但它依赖于类型化的键数组。现在升级到类型化的键为我们构建这些安全功能奠定了基础。
类型化 PRNG 键的设计#
类型化的 PRNG 键在 JAX 中作为扩展数据类型的实例实现,其中新的 PRNG 数据类型是子数据类型。
扩展数据类型#
从用户的角度来看,扩展数据类型 dt 具有以下用户可见的属性:
jax.dtypes.issubdtype(dt, jax.dtypes.extended)
返回True
:这是用于检测数据类型是否为扩展数据类型的公共 API。它具有类级别的属性
dt.type
,该属性返回numpy.generic
层次结构中的类型类。这类似于np.dtype('int32').type
返回numpy.int32
的方式,它不是数据类型,而是一个标量类型,并且是numpy.generic
的子类。与 numpy 标量类型不同,我们不允许实例化
dt.type
标量对象:这符合 JAX 将标量值表示为零维数组的决定。
从非公共实现的角度来看,扩展数据类型具有以下属性:
它的类型是私有基类
jax._src.dtypes.ExtendedDtype
的子类,这个非公开的基类用于扩展数据类型。ExtendedDtype
的实例类似于np.dtype
的实例,比如np.dtype('int32')
。它有一个私有的
_rules
属性,该属性允许数据类型定义它在特定操作下的行为。例如,当dtype
是扩展数据类型时,jax.lax.full(shape, fill_value, dtype)
将委托给dtype._rules.full(shape, fill_value, dtype)
。
为什么要在 PRNG 之外普遍引入扩展数据类型?我们在内部的其他地方也重用了这种相同的扩展数据类型机制。例如,jax._src.core.bint
对象,一种用于动态形状实验的有限整数类型,也是另一种扩展数据类型。在最近的 JAX 版本中,它满足上述属性(请参阅 jax/_src/core.py#L1789-L1802)。
PRNG 数据类型#
PRNG 数据类型被定义为扩展数据类型的特定情况。具体来说,此更改引入了一个新的公共标量类型类 jax.dtypes.prng_key,它具有以下属性
>>> jax.dtypes.issubdtype(jax.dtypes.prng_key, jax.dtypes.extended)
True
PRNG 密钥数组然后具有以下属性的数据类型
>>> key = jax.random.key(0)
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.extended)
True
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)
True
除了为扩展数据类型概述的 key.dtype._rules
之外,PRNG 数据类型还定义了 key.dtype._impl
,其中包含定义 PRNG 实现的元数据。PRNG 实现目前由非公开的 jax._src.prng.PRNGImpl
类定义。目前,PRNGImpl
不打算作为公共 API,但我们可能会很快重新考虑这一点,以允许完全自定义的 PRNG 实现。
进展#
以下是实现上述设计的关键拉取请求的非详尽列表。主要的跟踪问题是 #9263。
通过
PRNGImpl
实现可插拔 PRNG:#6899实现
PRNGKeyArray
,不带数据类型:#11952向
PRNGKeyArray
添加具有_rules
属性的“自定义元素”数据类型属性:#12167将“自定义元素类型”重命名为“不透明数据类型”:#12170
重构
bint
以使用不透明数据类型基础设施:#12707添加
jax.random.key
以直接创建类型化密钥:#16086向
key
和PRNGKey
添加impl
参数:#16589将“不透明数据类型”重命名为“扩展数据类型”并定义
jax.dtypes.extended
:#16824引入
jax.dtypes.prng_key
并将 PRNG 数据类型与扩展数据类型统一:#16781添加
jax_legacy_prng_key
标志,以支持在使用旧式(原始)PRNG 密钥时发出警告或错误:#17225