jax.scipy.stats.sem#
- jax.scipy.stats.sem(a, axis=0, ddof=1, nan_policy='propagate', *, keepdims=False)[源代码]#
计算均值的标准误差。
JAX 实现的
scipy.stats.sem()
。- 参数:
a (类数组) – 类数组
axis (int | None) – 可选整数。如果未指定,则输入数组将被展平。
ddof (int) – 整数,默认值=1。SEM 计算中的自由度。
nan_policy (str) – 字符串,默认值=”propagate”。JAX 仅支持 “propagate” 和 “omit”。
keepdims (bool) – 布尔值,默认值=False。如果为 true,则缩减的轴将保留在结果中,大小为 1。
- 返回:
数组
- 返回类型:
示例
>>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3]) >>> with jnp.printoptions(precision=2, suppress=True): ... jax.scipy.stats.sem(x) Array(0.41, dtype=float32)
对于多维数组,
sem
沿axis=0
计算均值的标准误差>>> x1 = jnp.array([[1, 2, 1, 3, 2, 1], ... [3, 1, 3, 2, 1, 3], ... [1, 2, 2, 3, 1, 2]]) >>> with jnp.printoptions(precision=2, suppress=True): ... jax.scipy.stats.sem(x1) Array([0.67, 0.33, 0.58, 0.33, 0.33, 0.58], dtype=float32)
如果
axis=1
,则均值的标准误差将沿axis 1
计算。>>> with jnp.printoptions(precision=2, suppress=True): ... jax.scipy.stats.sem(x1, axis=1) Array([0.33, 0.4 , 0.31], dtype=float32)
如果
axis=None
,则均值的标准误差将沿所有轴计算。>>> with jnp.printoptions(precision=2, suppress=True): ... jax.scipy.stats.sem(x1, axis=None) Array(0.2, dtype=float32)
默认情况下,
sem
会降低结果的维度。要保持维度与输入数组相同,必须将参数keepdims
设置为True
。>>> with jnp.printoptions(precision=2, suppress=True): ... jax.scipy.stats.sem(x1, axis=1, keepdims=True) Array([[0.33], [0.4 ], [0.31]], dtype=float32)
由于默认情况下
nan_policy='propagate'
,sem
会在结果中传播nan
值。>>> nan = jnp.nan >>> x2 = jnp.array([[1, 2, 3, nan, 4, 2], ... [4, 5, 4, 3, nan, 1], ... [7, nan, 8, 7, 9, nan]]) >>> with jnp.printoptions(precision=2, suppress=True): ... jax.scipy.stats.sem(x2) Array([1.73, nan, 1.53, nan, nan, nan], dtype=float32)
如果
nan_policy='omit'
,sem
将省略nan
值,并计算指定轴上剩余值的误差。>>> with jnp.printoptions(precision=2, suppress=True): ... jax.scipy.stats.sem(x2, nan_policy='omit') Array([1.73, 1.5 , 1.53, 2. , 2.5 , 0.5 ], dtype=float32)