jax.extend.random.unsafe_rbg_prng_impl#

jax.extend.random.unsafe_rbg_prng_impl = ((4,), <function _rbg_seed>, <function _unsafe_rbg_split>, <function _rbg_random_bits>, <function _unsafe_rbg_fold_in>, 'unsafe_rbg', 'urbg')#

指定 PRNG 键的形状和操作。

PRNG(伪随机数生成器)的实现由一个键类型 K 和一组操作这些键的函数决定。键类型 K 是一个数组类型,其元素类型为 uint32,形状由 key_shape 指定。每个操作的类型签名如下:

seed :: int[] -> K
fold_in :: K -> int[] -> K
split[shape] :: K -> K[*shape]
random_bits[shape, bit_width] :: K -> uint<bit_width>[*shape]

通过 PRNGKeyArray 类,PRNG 的实现可以适配到一个类似于键 K 的数组对象。该类应该通过 random_seed 函数创建。