JEP 9263:类型化键和可插拔 RNG#

Jake VanderPlas, Roy Frostig

2023 年 8 月

概述#

展望未来,JAX 中的 RNG 键将更加类型安全且可定制。它们不再用长度为 2 的 uint32 数组表示单个 PRNG 键,而是用具有特殊 RNG 数据类型的标量数组表示,该数据类型满足 jnp.issubdtype(key.dtype, jax.dtypes.prng_key)

目前,仍然可以使用 jax.random.PRNGKey() 创建旧式 RNG 键。

>>> key = jax.random.PRNGKey(0)
>>> key
Array([0, 0], dtype=uint32)
>>> key.shape
(2,)
>>> key.dtype
dtype('uint32')

从现在开始,可以使用 jax.random.key() 创建新式 RNG 键。

>>> key = jax.random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> key.shape
()
>>> key.dtype
key<fry>

这个(标量形状的)数组的行为与任何其他 JAX 数组相同,除了它的元素类型是一个键(以及相关元数据)。我们也可以创建非标量键数组,例如通过对 jax.vmap() 应用 jax.random.key()

>>> key_arr = jax.vmap(jax.random.key)(jnp.arange(4))
>>> key_arr
Array((4,), dtype=key<fry>) overlaying:
[[0 0]
 [0 1]
 [0 2]
 [0 3]]
>>> key_arr.shape
(4,)

除了切换到新的构造函数之外,大多数与 PRNG 相关的代码应该继续按预期工作。您可以像以前一样在 jax.random API 中继续使用键;例如

# split
new_key, subkey = jax.random.split(key)

# random number generation
data = jax.random.uniform(key, shape=(5,))

但是,并非所有数值运算都适用于键数组。现在它们有意地引发错误

>>> key = key + 1  
Traceback (most recent call last):
TypeError: add does not accept dtypes key<fry>, int32.

如果出于某种原因,你需要恢复底层缓冲区(旧式密钥),你可以使用 jax.random.key_data() 来实现。

>>> jax.random.key_data(key)
Array([0, 0], dtype=uint32)

对于旧式密钥,key_data() 是一个恒等操作。

这对用户意味着什么?#

对于 JAX 用户来说,此更改目前不需要任何代码更改,但我们希望你发现升级是值得的,并切换到使用类型化密钥。 要尝试这一点,请将 jax.random.PRNGKey() 的用法替换为 jax.random.key()。 这可能会在你的代码中引入一些问题,这些问题可以分为以下几类。

  • 如果你的代码对密钥执行不安全或不支持的操作(例如索引、算术运算、转置等;请参阅下面的类型安全性部分),此更改将捕获它们。 您可以更新您的代码以避免此类不支持的操作,或者使用 jax.random.key_data()jax.random.wrap_key_data() 以不安全的方式操作原始密钥缓冲区。

  • 如果你的代码包含关于 key.shape 的显式逻辑,你可能需要更新此逻辑以考虑到尾随密钥缓冲区维度不再是形状的显式部分。

  • 如果你的代码包含关于 key.dtype 的显式逻辑,你需要升级它以使用新的公共 API 来推理 RNG 数据类型,例如 dtypes.issubdtype(dtype, dtypes.prng_key)

  • 如果你调用了一个尚未处理类型化 PRNG 密钥的 JAX 库,你现在可以使用 raw_key = jax.random.key_data(key) 来恢复原始缓冲区,但请保留 TODO,以便在底层库支持类型化 RNG 密钥后将其删除。

在将来的某个时间点,我们计划弃用 jax.random.PRNGKey() 并要求使用 jax.random.key().

检测新式类型化密钥#

要检查一个对象是否是一个新式类型化 PRNG 密钥,可以使用 jax.dtypes.issubdtypejax.numpy.issubdtype

>>> typed_key = jax.random.key(0)
>>> jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key)
True
>>> raw_key = jax.random.PRNGKey(0)
>>> jax.dtypes.issubdtype(raw_key.dtype, jax.dtypes.prng_key)
False

PRNG 密钥的类型注释#

对于旧式和新式 PRNG 密钥,推荐的类型注释是 jax.Array。 一个 PRNG 密钥根据其 dtype 与其他数组区分开来,目前无法在类型注释中指定 JAX 数组的数据类型。 以前,可以使用 jax.random.KeyArrayjax.random.PRNGKeyArray 作为类型注释,但这些在类型检查下始终被别名为 Any,因此 jax.Array 具有更高的特异性。

注意:jax.random.KeyArrayjax.random.PRNGKeyArray 在 JAX 0.4.16 版本中被弃用,并在 JAX 0.4.24 版本中被移除。.

面向 JAX 库作者的说明#

如果你维护一个基于 JAX 的库,你的用户也是 JAX 用户。 请注意,JAX 目前将继续在 jax.random 中支持“原始”旧式密钥,因此调用者可能希望它们在所有地方都被接受。 如果你希望你的库要求使用新式类型化密钥,那么你可能希望使用以下类似的检查来强制执行它们。

from jax import dtypes

def ensure_typed_key_array(key: Array) -> Array:
  if dtypes.issubdtype(key.dtype, dtypes.prng_key):
    return key
  else:
    raise TypeError("New-style typed JAX PRNG keys required")

动机#

此更改的两个主要驱动因素是可定制性和安全性。

定制 PRNG 实现#

JAX 目前使用单个全局配置的 PRNG 算法运行。 一个 PRNG 密钥是一个无符号 32 位整数向量,jax.random API 使用它来生成伪随机流。 任何更高秩的 uint32 数组都被解释为这样的密钥缓冲区数组,其中尾随维度表示密钥。

当我们引入替代 PRNG 实现时,这种设计的缺点变得更加明显,这些实现必须通过设置全局或局部配置标志来选择。 不同的 PRNG 实现具有不同大小的密钥缓冲区,以及不同的生成随机位的算法。 使用全局标志确定这种行为很容易出错,尤其是在进程范围内存在多个密钥实现时。

我们的新方法是将实现作为 PRNG 密钥类型的一部分,即使用密钥数组的元素类型。 使用新的密钥 API,以下是在默认 threefry2x32 实现(在纯 Python 中实现并使用 JAX 编译)和非默认 rbg 实现(对应于单个 XLA 随机位生成操作)下生成伪随机值的示例。

>>> key = jax.random.key(0, impl='threefry2x32')  # this is the default impl
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.9653214 , 0.31468165, 0.63302994], dtype=float32)

>>> key = jax.random.key(0, impl='rbg')
>>> key
Array((), dtype=key<rbg>) overlaying:
[0 0 0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.39904642, 0.8805201 , 0.73571277], dtype=float32)

安全 PRNG 密钥使用#

原则上,PRNG 密钥只支持少数操作,即密钥派生(例如拆分)和随机数生成。 PRNG 被设计为生成独立的伪随机数,前提是密钥被正确拆分并且每个密钥只被使用一次。

以其他方式操作或使用密钥数据的代码通常表明存在意外错误,而将密钥数组表示为原始 uint32 缓冲区允许沿这些路线轻松地误用。 以下是一些我们在实践中遇到的误用示例。

密钥缓冲区索引#

访问底层整数缓冲区使得尝试以非标准方式派生密钥变得容易,有时会带来意外的严重后果。

# Incorrect
key = random.PRNGKey(999)
new_key = random.PRNGKey(key[1])  # identical to the original key!
# Correct
key = random.PRNGKey(999)
key, new_key = random.split(key)

如果这个密钥是用 random.key(999) 创建的新式类型化密钥,那么索引到密钥缓冲区将导致错误。

密钥算术运算#

密钥算术运算是一种同样危险的方式,可以从其他密钥派生密钥。 以避免 jax.random.split()jax.random.fold_in() 的方式直接操作密钥数据,会产生一批密钥,这些密钥(取决于 PRNG 实现)可能会在批次中生成相关的随机数。

# Incorrect
key = random.PRNGKey(0)
batched_keys = key + jnp.arange(10, dtype=key.dtype)[:, None]
# Correct
key = random.PRNGKey(0)
batched_keys = random.split(key, 10)

使用 random.key(0) 创建的新式类型化密钥通过禁止对密钥执行算术运算来解决这个问题。

意外转置密钥缓冲区#

使用“原始”旧式密钥数组,很容易意外地交换批次(前导)维度和密钥缓冲区(尾随)维度。 这同样会导致密钥可能生成相关的伪随机性。 我们一直在观察到的一个模式归结为以下内容。

# Incorrect
keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=1)(keys)
# Correct
keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=0)(keys)

这里的错误很微妙。 通过在 in_axes=1 上进行映射,此代码通过组合批次中每个密钥缓冲区中的单个元素来生成新的密钥。 生成的密钥彼此不同,但实际上是以非标准方式“派生”的。 同样,PRNG 不是设计或测试过从这样的密钥批次生成独立的随机流。

使用 random.key(0) 创建的新式类型化密钥通过隐藏单个密钥的缓冲区表示来解决这个问题,而是将密钥视为密钥数组的不透明元素。 密钥数组没有尾随的“缓冲区”维度可以进行索引、转置或映射。

密钥重用#

与像 numpy.random 这样的基于状态的 PRNG API 不同,JAX 的函数式 PRNG 不会在密钥使用后隐式更新密钥。

# Incorrect
key = random.PRNGKey(0)
x = random.uniform(key, (100,))
y = random.uniform(key, (100,))  # Identical values!
# Correct
key = random.PRNGKey(0)
key1, key2 = random.split(random.key(0))
x = random.uniform(key1, (100,))
y = random.uniform(key2, (100,))

我们正在积极努力开发工具来检测和防止意外的密钥重用。 这项工作仍在进行中,但它依赖于类型化密钥数组。 现在升级到类型化密钥为我们提供了一个基础,以便在我们构建这些安全功能时引入它们。

类型化 PRNG 密钥的设计#

类型化 PRNG 密钥被实现为 JAX 中扩展数据类型的实例,其中新的 PRNG 数据类型是其子数据类型。

扩展数据类型#

从用户的角度来看,扩展数据类型 dt 具有以下用户可见的属性。

  • jax.dtypes.issubdtype(dt, jax.dtypes.extended) 返回 True:这是应该用于检测数据类型是否为扩展数据类型的公共 API。

  • 它有一个类级属性 dt.type,它返回 numpy.generic 层次结构中的一个类型类。 这类似于 np.dtype('int32').type 如何返回 numpy.int32,它不是数据类型,而是一个标量类型,并且是 numpy.generic 的子类。

  • 与 numpy 标量类型不同,我们不允许实例化 dt.type 标量对象:这与 JAX 将标量值表示为零维数组的决定一致。

从非公共实现的角度来看,扩展数据类型具有以下属性。

  • 它的类型是私有基类 jax._src.dtypes.ExtendedDtype 的子类,它是用于扩展数据类型的非公共基类。 ExtendedDtype 的实例类似于 np.dtype 的实例,例如 np.dtype('int32')

  • 它有一个私有 _rules 属性,它允许数据类型定义它在特定操作下的行为方式。 例如,jax.lax.full(shape, fill_value, dtype)dtype 是一个扩展数据类型时,将委托给 dtype._rules.full(shape, fill_value, dtype)

为什么除了 PRNG 之外,还要在通用情况下引入扩展数据类型?我们会在内部其他地方重复使用相同的扩展数据类型机制。例如,jax._src.core.bint 对象,一种用于动态形状实验工作的有界整数类型,也是另一种扩展数据类型。在最近的 JAX 版本中,它满足上述属性(参见 jax/_src/core.py#L1789-L1802)。

PRNG 数据类型#

PRNG 数据类型被定义为扩展数据类型的一种特殊情况。具体来说,此更改引入了一种新的公共标量类型类 jax.dtypes.prng_key,它具有以下属性

>>> jax.dtypes.issubdtype(jax.dtypes.prng_key, jax.dtypes.extended)
True

然后,PRNG 键数组具有具有以下属性的数据类型

>>> key = jax.random.key(0)
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.extended)
True
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)
True

除了针对一般扩展数据类型概述的 key.dtype._rules 之外,PRNG 数据类型还定义了 key.dtype._impl,其中包含定义 PRNG 实现的元数据。PRNG 实现目前由非公共 jax._src.prng.PRNGImpl 类定义。目前,PRNGImpl 不打算成为公共 API,但我们可能会很快重新审视这一点,以允许完全自定义的 PRNG 实现。

进度#

以下是实现上述设计的关键 Pull 请求的不完整列表。主要跟踪问题是 #9263

  • 通过 PRNGImpl 实现可插拔 PRNG: #6899

  • 实现 PRNGKeyArray,没有数据类型: #11952

  • 使用 _rules 属性向 PRNGKeyArray 添加“自定义元素”数据类型属性: #12167

  • 将“自定义元素类型”重命名为“不透明数据类型”: #12170

  • 重构 bint 以使用不透明数据类型基础设施: #12707

  • 添加 jax.random.key 以直接创建类型化键: #16086

  • keyPRNGKey 添加 impl 参数: #16589

  • 将“不透明数据类型”重命名为“扩展数据类型”并定义 jax.dtypes.extended#16824

  • 引入 jax.dtypes.prng_key 并将 PRNG 数据类型与扩展数据类型统一: #16781

  • 添加 jax_legacy_prng_key 标志以支持在使用旧版(原始)PRNG 键时发出警告或错误: #17225