jax.experimental.key_reuse 模块

jax.experimental.key_reuse 模块#

实验性密钥重用检查#

此模块包含用于检测 JAX 程序中随机密钥重用的**实验性**功能。它处于积极开发中,此处的 API 可能会发生变化。以下用法需要 JAX 版本 0.4.26 或更高版本。

可以使用 jax_debug_key_reuse 配置启用密钥重用检查。可以使用以下方法全局设置:

>>> jax.config.update('jax_debug_key_reuse', True)  

或者可以使用 jax.debug_key_reuse() 上下文管理器在本地启用。启用后,两次使用相同的密钥将导致 KeyReuseError

>>> import jax
>>> with jax.debug_key_reuse(True):
...   key = jax.random.key(0)
...   val1 = jax.random.normal(key)
...   val2 = jax.random.normal(key)  
Traceback (most recent call last):
 ...
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0

密钥重用检查器目前是实验性的,但将来我们可能会默认启用它。