jax.experimental.key_reuse
模块#
实验性密钥重用检查#
此模块包含用于检测 JAX 程序中随机密钥重用的**实验性**功能。它处于积极开发中,此处的 API 可能会发生变化。以下用法需要 JAX 版本 0.4.26 或更高版本。
可以使用 jax_debug_key_reuse
配置启用密钥重用检查。可以使用以下方法全局设置:
>>> jax.config.update('jax_debug_key_reuse', True)
或者可以使用 jax.debug_key_reuse()
上下文管理器在本地启用。启用后,两次使用相同的密钥将导致 KeyReuseError
>>> import jax
>>> with jax.debug_key_reuse(True):
... key = jax.random.key(0)
... val1 = jax.random.normal(key)
... val2 = jax.random.normal(key)
Traceback (most recent call last):
...
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
密钥重用检查器目前是实验性的,但将来我们可能会默认启用它。