jax.random.split

内容

jax.random.split#

jax.random.split(key, num=2)[source]#

将一个 PRNG 键分割成 num 个新键,方法是在前面添加一个轴。

参数:
  • key (KeyArrayLike) – 一个 PRNG 键(来自 keysplitfold_in)。

  • num (int | tuple[int, ...]) – 可选,一个正整数(或整数元组),指示要生成的键的数量(或形状)。默认为 2。

返回值:

一个包含 num 个新 PRNG 键的类数组对象。

返回类型:

KeyArray