jax.random.wrap_key_data#

jax.random.wrap_key_data(key_bits_array, *, impl=None)[源代码]#

将一个密钥数据位数组封装成一个 PRNG 密钥数组。

参数:
  • key_bits_array (Array) – 一个 uint32 数组,其尾部形状对应于由 impl 指定的 PRNG 实现的密钥形状。

  • impl (PRNGSpecDesc | None | None) – 可选,指定一个 PRNG 实现,如 random.key 中所示。

返回:

一个 PRNG 密钥数组,其 dtype 是 jax.dtypes.prng_key 的子类型

,对应于 impl,并且其形状等于 key_bits_array.shape 的前导形状,直至密钥位维度。