jax.numpy.percentile#

jax.numpy.percentile(a, q, axis=None, out=None, overwrite_input=False, method='linear', keepdims=False, *, interpolation=Deprecated)[源代码]#

计算指定轴上的数据百分位数。

numpy.percentile() 的 JAX 实现。

参数:
  • a (ArrayLike) – N维数组输入。

  • q (ArrayLike) – 标量或一维数组,指定所需的分位数。q 应包含介于 0100 之间的整数或浮点数值。

  • axis (int | tuple[int, ...] | None) – 可选的轴或轴元组,用于计算分位数。

  • out (None) – JAX 未实现;如果不是 None,则会报错。

  • overwrite_input (bool) – JAX 未实现;如果不是 False,则会报错。

  • method (str) – 指定要使用的插值方法。选项为 ["linear", "lower", "higher", "midpoint", "nearest"] 之一。默认值为 linear

  • keepdims (bool) – 如果为 True,则返回的数组将具有与输入相同的维度数。默认值为 False。

  • interpolation (str | DeprecatedArg) – method 参数的已弃用别名。如果使用,将导致 DeprecationWarning

返回:

一个包含沿指定轴的指定百分位数的数组。

返回类型:

数组

另请参阅

示例

计算一维数组的中位数和四分位数

>>> x = jnp.array([0, 1, 2, 3, 4, 5, 6])
>>> q = jnp.array([25, 50, 75])
>>> jnp.percentile(x, q)
Array([1.5, 3. , 4.5], dtype=float32)

使用最近邻而不是线性插值计算相同的百分位数

>>> jnp.percentile(x, q, method='nearest')
Array([1., 3., 4.], dtype=float32)