jax.random.categorical#
- jax.random.categorical(key, logits, axis=-1, shape=None)[source]#
从分类分布中采样随机值。
- 参数::
key (KeyArrayLike) – 用作随机键的 PRNG 键。
logits (RealArray) – 要从中采样的分类分布的未归一化对数概率,因此 softmax(logits, axis) 给出了相应的概率。
axis (int) – 对数属于同一个分类分布的轴。
shape (Shape | None | None) – 可选,表示结果形状的非负整数元组。必须与
np.delete(logits.shape, axis)
广播兼容。默认值 (None) 生成等于np.delete(logits.shape, axis)
的结果形状。
- 返回值::
如果
shape
不为 None,则为具有 int 数据类型和shape
给出的形状的随机数组,否则为np.delete(logits.shape, axis)
。- 返回类型::