jax.random.choice

内容

jax.random.choice#

jax.random.choice(key, a, shape=(), replace=True, p=None, axis=0)[source]#

从给定数组中生成随机样本。

警告

如果 p 的非零元素数量少于请求的样本数量(如 shape 中指定),并且 replace=False,则此函数的输出未定义。请确保使用适当的输入。

参数:
  • key (KeyArrayLike) – 用作随机密钥的 PRNG 密钥。

  • a (int | ArrayLike) – 数组或整数。如果为 ndarray,则从其元素生成随机样本。如果为整数,则生成随机样本,就像 a 为 arange(a) 一样。

  • shape (Shape) – 整数元组,可选。输出形状。如果给定的形状为,例如,(m, n),则绘制 m * n 个样本。默认为 (),在这种情况下返回单个值。

  • replace (bool) – 布尔值。样本是否进行放回抽样。默认为 True。

  • p (RealArray | None | None) – 一维类数组,与 a 中每个条目关联的概率。如果未给出,则样本假设 a 中所有条目上的均匀分布。

  • axis (int) – 整数,可选。执行选择的轴。默认值 0 按行选择。

返回:

一个形状为 shape 的数组,包含来自 a 的样本。

返回类型:

数组