jax.numpy.argpartition

jax.numpy.argpartition#

jax.numpy.argpartition(a, kth, axis=-1)[source]#

返回部分排序数组的索引。

JAX 实现的 numpy.argpartition()。JAX 版本在处理 NaN 项时与 NumPy 不同:设置了负位的 NaN 将排序到数组的开头。

参数::
  • a (ArrayLike) – 要进行部分排序的数组。

  • kth (int) – 用于部分排序数组的静态整数索引。

  • axis (int) – 用于部分排序数组的静态整数轴;默认值为 -1。

返回::

沿 axis 轴,对 a 进行分区,返回分区索引,其中 kth 为分区位置。 kth 之前的索引对应于小于 take(a, kth, axis) 的值,kth 之后的索引对应于大于 take(a, kth, axis) 的值。

返回类型:

数组

注意

JAX 版本要求 kth 参数为静态整数,而非通用数组。 此功能通过两次调用 jax.lax.top_k() 来实现。 如果您只需要访问输出的顶部或底部 k 个值,则直接调用 jax.lax.top_k() 可能效率更高。

另请参见

示例

>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> idx = jnp.argpartition(x, kth)
>>> idx
Array([4, 8, 3, 9, 2, 0, 1, 5, 6, 7], dtype=int32)

结果是索引序列,对输入进行部分排序。 所有 kth 之前的索引对应于小于枢纽值的值,所有 kth 之后的索引对应于大于枢纽值的值

>>> x_partitioned = x[idx]
>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [6 8 9 7 5]

请注意,在 smallest_valueslargest_values 中,返回的顺序是任意的,并且取决于实现。