jax.random.categorical#

jax.random.categorical(key, logits, axis=-1, shape=None)[源代码]#

从分类分布中采样随机值。

参数:
  • key (ArrayLike) – 用作随机键的 PRNG 键。

  • logits (RealArray) – 要从中采样的分类分布的未归一化对数概率,使得 softmax(logits, axis) 给出相应的概率。

  • axis ( int ) – 对数概率(logits)所属的同一类别分布的轴。

  • shape (Shape | None | None) – 可选,表示结果形状的非负整数元组。 必须与 np.delete(logits.shape, axis) 广播兼容。 默认值 (None) 生成一个等于 np.delete(logits.shape, axis) 的结果形状。

返回:

如果 shape 不是 None,则返回一个 int 类型且形状由 shape 给定的随机数组,否则返回一个形状为 np.delete(logits.shape, axis) 的随机数组。

返回类型:

数组