jax.scipy.stats.rankdata

内容

jax.scipy.stats.rankdata#

jax.scipy.stats.rankdata(a, method='average', *, axis=None, nan_policy='propagate')[source]#

沿数组轴计算数据的秩。

JAX 实现的 scipy.stats.rankdata().

秩从 1 开始,并且 *method* 参数控制如何处理并列。

参数:
  • a (ArrayLike) – 类数组

  • method (str) – str,默认值 =“average”。支持的方法为 ("average", "min", "max", "dense", "ordinal") 有关详细信息,请参阅 scipy.stats.rankdata() 文档。

  • axis (int | None) – 可选整数。如果未指定,则输入数组将被展平。

  • nan_policy (str) – str,JAX 的实现仅支持 "propagate"

返回值:

沿指定轴的秩数组。

返回类型:

数组

示例

>>> x = jnp.array([10, 30, 20])
>>> rankdata(x)
Array([1., 3., 2.], dtype=float32)
>>> x = jnp.array([1, 3, 2, 3])
>>> rankdata(x)
Array([1. , 3.5, 2. , 3.5], dtype=float32)