jax.random.multivariate_normal#

jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=None, method='cholesky')[source]#

根据给定的均值和协方差,采样多元正态分布的随机值。

返回的值遵循以下概率密度函数:

\[f(x;\mu, \Sigma) = (2\pi)^{-k/2} \det(\Sigma)^{-1}e^{-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)}\]

其中 \(k\) 是维度,\(\mu\) 是均值(由 mean 给出),\(\Sigma\) 是协方差矩阵(由 cov 给出)。

参数:
  • key (ArrayLike) – 用作随机密钥的 PRNG 密钥。

  • mean (RealArray) – 形状为 (..., n) 的均值向量。

  • cov (RealArray) – 形状为 (..., n, n) 的正定协方差矩阵。批处理形状 ... 必须与 mean 的批处理形状广播兼容。

  • shape (Shape | None | None) – 可选,一个非负整数元组,指定结果的批处理形状;即,结果形状的前缀,不包括最后一个轴。必须与 mean.shape[:-1]cov.shape[:-2] 广播兼容。默认值 (None) 通过广播 meancov 的批处理形状来生成结果的批处理形状。

  • dtype (DTypeLikeFloat | None | None) – 可选,返回值的浮点数据类型(如果 jax_enable_x64 为 true,则默认为 float64,否则为 float32)。

  • method (str) – 可选,用于计算 cov 因子的方法。必须是 ‘svd’、‘eigh’ 和 ‘cholesky’ 之一。默认为 ‘cholesky’。对于奇异协方差矩阵,请使用 ‘svd’ 或 ‘eigh’。

返回:

一个随机数组,其指定的数据类型和形状由 shape + mean.shape[-1:] 给出(如果 shape 不为 None),否则为 broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]

返回类型:

Array