jax.scipy.stats.mode

内容

jax.scipy.stats.mode#

jax.scipy.stats.mode(a, axis=0, nan_policy='propagate', keepdims=False)[source]#

计算数组沿轴的众数(最常见的值)。

JAX 实现 scipy.stats.mode().

参数:
  • a (ArrayLike) – 类数组

  • axis (int | None) – int,默认值为 0。计算众数的轴。

  • nan_policy (str) – str。JAX 只支持 "propagate"

  • keepdims (bool) – bool,默认值为 False。如果为 True,则保留大小为 1 的缩减轴。

返回值:

一个包含数组的元组,(mode, count)mode 是众数值数组,count 是每个值在输入数组中出现的次数。

返回值类型:

ModeResult

示例

>>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
>>> mode, count = jax.scipy.stats.mode(x)
>>> mode, count
(Array(4, dtype=int32), Array(3, dtype=int32))

对于多维数组,jax.scipy.stats.mode 计算 mode 和相应的 count 沿 axis=0

>>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
...                 [3, 1, 3, 2, 1, 3],
...                 [1, 2, 2, 3, 1, 2]])
>>> mode, count = jax.scipy.stats.mode(x1)
>>> mode, count
(Array([1, 2, 1, 3, 1, 1], dtype=int32), Array([2, 2, 1, 2, 2, 1], dtype=int32))

如果 axis=1modecount 将沿 axis 1 计算。

>>> mode, count = jax.scipy.stats.mode(x1, axis=1)
>>> mode, count
(Array([1, 3, 2], dtype=int32), Array([3, 3, 3], dtype=int32))

默认情况下,jax.scipy.stats.mode 会减少结果的维度。为了保持与输入数组相同的维度,必须将参数 keepdims 设置为 True

>>> mode, count = jax.scipy.stats.mode(x1, axis=1, keepdims=True)
>>> mode, count
(Array([[1],
       [3],
       [2]], dtype=int32), Array([[3],
       [3],
       [3]], dtype=int32))