JEP 9263:类型化键和可插拔 RNG#
Jake VanderPlas,Roy Frostig
2023 年 8 月
概述#
未来,JAX 中的 RNG 密钥将具有更高的类型安全性和可定制性。它不再使用长度为 2 的 uint32
数组来表示单个 PRNG 密钥,而是使用具有特殊 RNG dtype 的标量数组来表示,该标量数组满足 jnp.issubdtype(key.dtype, jax.dtypes.prng_key)
。
目前,仍可以使用 jax.random.PRNGKey()
创建旧式 RNG 密钥。
>>> key = jax.random.PRNGKey(0)
>>> key
Array([0, 0], dtype=uint32)
>>> key.shape
(2,)
>>> key.dtype
dtype('uint32')
从现在开始,可以使用 jax.random.key()
创建新式 RNG 密钥。
>>> key = jax.random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> key.shape
()
>>> key.dtype
key<fry>
这种(标量形状的)数组的行为与任何其他 JAX 数组相同,只是它的元素类型是一个密钥(以及相关的元数据)。我们也可以创建非标量密钥数组,例如通过将 jax.vmap()
应用于 jax.random.key()
来实现。
>>> key_arr = jax.vmap(jax.random.key)(jnp.arange(4))
>>> key_arr
Array((4,), dtype=key<fry>) overlaying:
[[0 0]
[0 1]
[0 2]
[0 3]]
>>> key_arr.shape
(4,)
除了切换到新的构造函数外,大多数与 PRNG 相关的代码应继续按预期工作。你可以继续在 jax.random
API 中像以前一样使用密钥;例如
# split
new_key, subkey = jax.random.split(key)
# random number generation
data = jax.random.uniform(key, shape=(5,))
但是,并非所有数值运算都适用于密钥数组。现在,它们会故意引发错误
>>> key = key + 1
Traceback (most recent call last):
TypeError: add does not accept dtypes key<fry>, int32.
如果出于某种原因需要恢复底层缓冲区(旧式密钥),可以使用 jax.random.key_data()
来实现。
>>> jax.random.key_data(key)
Array([0, 0], dtype=uint32)
对于旧式密钥,key_data()
是一个恒等运算。
这对用户意味着什么?#
对于 JAX 用户,此更改现在不需要任何代码更改,但我们希望你觉得这次升级是值得的,并切换到使用类型化的密钥。要尝试此操作,请将 jax.random.PRNGKey() 的使用替换为 jax.random.key()。这可能会在你的代码中引入一些故障,这些故障可分为以下几类
如果你的代码对密钥执行不安全/不支持的操作(例如索引、算术、转置等;请参阅下面的“类型安全”部分),此更改将捕获它。你可以更新代码以避免此类不支持的操作,或者使用
jax.random.key_data()
和jax.random.wrap_key_data()
以不安全的方式操作原始密钥缓冲区。如果你的代码包含有关
key.shape
的显式逻辑,则可能需要更新此逻辑以考虑尾随密钥缓冲区维度不再是形状的显式部分这一事实。如果你的代码包含有关
key.dtype
的显式逻辑,则需要将其升级为使用新的公共 API 来推理 RNG dtype,例如dtypes.issubdtype(dtype, dtypes.prng_key)
。如果调用尚未处理类型化 PRNG 密钥的基于 JAX 的库,则现在可以使用
raw_key = jax.random.key_data(key)
来恢复原始缓冲区,但请保留一个 TODO,以便在下游库支持类型化的 RNG 密钥后删除它。
在未来的某个时候,我们计划弃用 jax.random.PRNGKey()
并要求使用 jax.random.key()
。
检测新式类型化密钥#
要检查对象是否为新式类型化 PRNG 密钥,可以使用 jax.dtypes.issubdtype
或 jax.numpy.issubdtype
>>> typed_key = jax.random.key(0)
>>> jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key)
True
>>> raw_key = jax.random.PRNGKey(0)
>>> jax.dtypes.issubdtype(raw_key.dtype, jax.dtypes.prng_key)
False
PRNG 密钥的类型注释#
新旧两种样式的 PRNG 密钥的推荐类型注释均为 jax.Array
。PRNG 密钥根据其 dtype
与其他数组区分开来,并且当前无法在类型注释中指定 JAX 数组的 dtype。以前可以使用 jax.random.KeyArray
或 jax.random.PRNGKeyArray
作为类型注释,但这些在类型检查下始终被别名为 Any
,因此 jax.Array
具有更高的特异性。
注意:jax.random.KeyArray
和 jax.random.PRNGKeyArray
已在 JAX 版本 0.4.16 中弃用,并在 JAX 版本 0.4.24 中删除.
动机#
此更改的两个主要动机是可定制性和安全性。
自定义 PRNG 实现#
JAX 当前使用单个全局配置的 PRNG 算法运行。PRNG 密钥是无符号 32 位整数的向量,jax.random API 使用这些向量来生成伪随机流。任何更高秩的 uint32 数组都被解释为此类密钥缓冲区的数组,其中尾随维度表示密钥。
当我们引入替代 PRNG 实现时,此设计的缺点变得更加明显,必须通过设置全局或局部配置标志来选择这些实现。不同的 PRNG 实现具有不同大小的密钥缓冲区和不同的随机位生成算法。使用全局标志确定此行为容易出错,尤其是在整个进程中使用多个密钥实现时。
我们的新方法是将实现作为 PRNG 密钥类型的一部分进行携带,即使用密钥数组的元素类型。使用新的密钥 API,以下示例说明了在默认的 threefry2x32 实现(在纯 Python 中实现并使用 JAX 编译)和非默认的 rbg 实现(对应于单个 XLA 随机位生成操作)下生成伪随机值
>>> key = jax.random.key(0, impl='threefry2x32') # this is the default impl
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.9653214 , 0.31468165, 0.63302994], dtype=float32)
>>> key = jax.random.key(0, impl='rbg')
>>> key
Array((), dtype=key<rbg>) overlaying:
[0 0 0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.39904642, 0.8805201 , 0.73571277], dtype=float32)
安全的 PRNG 密钥使用#
原则上,PRNG 密钥实际上只支持几个操作,即密钥派生(例如拆分)和随机数生成。PRNG 旨在生成独立的伪随机数,前提是密钥已正确拆分且每个密钥仅使用一次。
以其他方式操作或使用密钥数据的代码通常表示意外的错误,并且将密钥数组表示为原始 uint32 缓冲区使沿着这些方向的滥用变得容易。以下是一些我们在野外遇到的滥用示例
密钥缓冲区索引#
访问底层整数缓冲区可以很容易地尝试以非标准方式派生密钥,有时会导致意想不到的糟糕后果
# Incorrect
key = random.PRNGKey(999)
new_key = random.PRNGKey(key[1]) # identical to the original key!
# Correct
key = random.PRNGKey(999)
key, new_key = random.split(key)
如果此密钥是使用 random.key(999)
创建的新式类型化密钥,则索引到密钥缓冲区中会出错。
密钥算术#
密钥算术是另一种从其他密钥派生密钥的危险方式。以避免使用 jax.random.split()
或 jax.random.fold_in()
通过直接操作密钥数据的方式派生密钥会生成一批密钥,这些密钥(取决于 PRNG 实现)可能会在批次内生成相关的随机数
# Incorrect
key = random.PRNGKey(0)
batched_keys = key + jnp.arange(10, dtype=key.dtype)[:, None]
# Correct
key = random.PRNGKey(0)
batched_keys = random.split(key, 10)
使用 random.key(0)
创建的新式类型化密钥通过禁止对密钥执行算术运算来解决此问题。
密钥缓冲区的意外转置#
使用“原始”旧式密钥数组,很容易意外地交换批次(前导)维度和密钥缓冲区(尾随)维度。同样,这可能会导致密钥产生相关的伪随机性。我们随着时间的推移看到的模式可以归结为以下几点
# Incorrect
keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=1)(keys)
# Correct
keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=0)(keys)
这里的错误很微妙。通过映射 in_axes=1
,此代码通过组合批次中每个密钥缓冲区中的单个元素来生成新密钥。生成的密钥彼此不同,但实际上是以非标准方式“派生”的。同样,PRNG 没有设计或测试为从这样的密钥批次生成独立的随机流。
使用 random.key(0)
创建的新式类型化密钥通过隐藏单个密钥的缓冲区表示来解决此问题,而是将密钥视为密钥数组的不透明元素。密钥数组没有尾随的“缓冲区”维度来索引、转置或映射。
密钥重用#
与基于状态的 PRNG API(如 numpy.random
)不同,JAX 的函数式 PRNG 在使用密钥后不会隐式更新密钥。
# Incorrect
key = random.PRNGKey(0)
x = random.uniform(key, (100,))
y = random.uniform(key, (100,)) # Identical values!
# Correct
key = random.PRNGKey(0)
key1, key2 = random.split(random.key(0))
x = random.uniform(key1, (100,))
y = random.uniform(key2, (100,))
我们正在积极开发工具来检测和防止意外的密钥重用。这仍在进行中,但它依赖于类型化的密钥数组。现在升级到类型化密钥使我们能够在构建这些安全功能时引入它们。
类型化 PRNG 密钥的设计#
类型化的 PRNG 密钥在 JAX 中实现为扩展 dtype 的实例,其中新的 PRNG dtype 是子 dtype。
扩展 dtype#
从用户角度来看,扩展 dtype dt 具有以下用户可见的属性
jax.dtypes.issubdtype(dt, jax.dtypes.extended)
返回True
:这是用于检测 dtype 是否为扩展 dtype 的公共 API。它有一个类级别的属性
dt.type
,它返回numpy.generic
层次结构中的一个类型类。这类似于np.dtype('int32').type
如何返回numpy.int32
,后者不是一个 dtype,而是一个标量类型,并且是numpy.generic
的子类。与 numpy 标量类型不同,我们不允许实例化
dt.type
标量对象:这符合 JAX 将标量值表示为零维数组的决定。
从非公开的实现角度来看,扩展的 dtype 具有以下属性:
它的类型是私有基类
jax._src.dtypes.ExtendedDtype
的子类,后者是用于扩展 dtype 的非公开基类。ExtendedDtype
的实例类似于np.dtype
的实例,例如np.dtype('int32')
。它有一个私有的
_rules
属性,允许 dtype 定义其在特定操作下的行为。例如,当dtype
是扩展的 dtype 时,jax.lax.full(shape, fill_value, dtype)
将委托给dtype._rules.full(shape, fill_value, dtype)
。
为什么要引入通用的扩展 dtype,而不仅仅是 PRNG?我们在内部的其他地方也重用了相同的扩展 dtype 机制。例如,jax._src.core.bint
对象,一种用于动态形状实验的边界整数类型,是另一个扩展的 dtype。在最近的 JAX 版本中,它满足上述属性(参见 jax/_src/core.py#L1789-L1802)。
PRNG dtypes#
PRNG dtypes 被定义为扩展 dtype 的一个特殊情况。具体来说,此更改引入了一个新的公共标量类型类 jax.dtypes.prng_key,它具有以下属性:
>>> jax.dtypes.issubdtype(jax.dtypes.prng_key, jax.dtypes.extended)
True
然后,PRNG 密钥数组具有具有以下属性的 dtype:
>>> key = jax.random.key(0)
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.extended)
True
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)
True
除了如上所述的扩展 dtype 通用的 key.dtype._rules
之外,PRNG dtypes 还定义了 key.dtype._impl
,其中包含定义 PRNG 实现的元数据。PRNG 实现目前由非公开的 jax._src.prng.PRNGImpl
类定义。目前,PRNGImpl
并非旨在成为公共 API,但我们可能会很快重新考虑这一点,以允许完全自定义的 PRNG 实现。
进展#
以下是实现上述设计的关键 Pull Requests 的非详尽列表。主要跟踪问题是 #9263。
通过
PRNGImpl
实现可插拔的 PRNG:#6899实现
PRNGKeyArray
,不带 dtype:#11952向
PRNGKeyArray
添加具有_rules
属性的“自定义元素”dtype 属性:#12167将“自定义元素类型”重命名为“不透明 dtype”:#12170
重构
bint
以使用不透明 dtype 基础设施:#12707添加
jax.random.key
以直接创建类型化的密钥:#16086向
key
和PRNGKey
添加impl
参数:#16589将“不透明 dtype”重命名为“扩展 dtype”并定义
jax.dtypes.extended
:#16824引入
jax.dtypes.prng_key
并将 PRNG dtype 与扩展 dtype 统一:#16781添加一个
jax_legacy_prng_key
标志,以支持在使用旧版(原始)PRNG 密钥时发出警告或错误:#17225