jax.numpy.nanquantile

内容

jax.numpy.nanquantile#

jax.numpy.nanquantile(a, q, axis=None, out=None, overwrite_input=False, method='linear', keepdims=False, *, interpolation=Deprecated)[source]#

计算数据沿指定轴的百分位数,忽略 NaN。

JAX 实现 numpy.nanquantile().

参数:
  • a (ArrayLike) – N 维数组输入。

  • q (ArrayLike) – 标量或一维数组,指定所需分位数。 q 应包含介于 0.01.0 之间的浮点值。

  • axis (int | tuple[int, ...] | None) – 可选的轴或轴元组,沿着这些轴计算分位数。

  • out (None) – JAX 未实现;如果非 None,则会出错。

  • overwrite_input (bool) – JAX 未实现;如果非 False,则会出错。

  • method (str) – 指定要使用的插值方法。选项之一是 ["linear", "lower", "higher", "midpoint", "nearest"]。默认值为 linear

  • keepdims (bool) – 如果为 True,则返回的数组将与输入数组具有相同的维数。默认值为 False。

  • interpolation (DeprecatedArg | str) – method 参数的弃用别名。如果使用,会导致 DeprecationWarning

返回:

包含沿指定轴的指定分位数的数组。

返回类型:

数组

另请参阅

示例

计算一维数组的中位数和四分位数

>>> x = jnp.array([0, 1, 2, jnp.nan, 3, 4, 5, 6])
>>> q = jnp.array([0.25, 0.5, 0.75])

由于存在 NaN 值,jax.numpy.quantile() 返回全 NaN,而 nanquantile() 忽略 NaN

>>> jnp.quantile(x, q)
Array([nan, nan, nan], dtype=float32)
>>> jnp.nanquantile(x, q)
Array([1.5, 3. , 4.5], dtype=float32)