jax.scipy.stats.mode#

jax.scipy.stats.mode(a, axis=0, nan_policy='propagate', keepdims=False)[source]#

计算数组沿轴的众数(最常见的值)。

JAX 对 scipy.stats.mode() 的实现。

参数:
  • a (ArrayLike) – 类数组

  • axis (int | None) – int, 默认值=0。计算众数的轴。

  • nan_policy (str) – str。JAX 仅支持 "propagate"

  • keepdims (bool) – bool, 默认值=False。如果为 true,则结果中保留缩减的轴,大小为 1。

返回值:

一个数组元组,(mode, count)mode 是众数值的数组,count 是每个值在输入数组中出现的次数。

返回类型:

ModeResult

示例

>>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
>>> mode, count = jax.scipy.stats.mode(x)
>>> mode, count
(Array(4, dtype=int32), Array(3, dtype=int32))

对于多维数组,jax.scipy.stats.mode 计算沿 axis=0mode 和相应的 count

>>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
...                 [3, 1, 3, 2, 1, 3],
...                 [1, 2, 2, 3, 1, 2]])
>>> mode, count = jax.scipy.stats.mode(x1)
>>> mode, count
(Array([1, 2, 1, 3, 1, 1], dtype=int32), Array([2, 2, 1, 2, 2, 1], dtype=int32))

如果 axis=1,则将沿 axis 1 计算 modecount

>>> mode, count = jax.scipy.stats.mode(x1, axis=1)
>>> mode, count
(Array([1, 3, 2], dtype=int32), Array([3, 3, 3], dtype=int32))

默认情况下,jax.scipy.stats.mode 会减小结果的维度。要保持维度与输入数组相同,必须将参数 keepdims 设置为 True

>>> mode, count = jax.scipy.stats.mode(x1, axis=1, keepdims=True)
>>> mode, count
(Array([[1],
       [3],
       [2]], dtype=int32), Array([[3],
       [3],
       [3]], dtype=int32))