jax.numpy.argpartition#
- jax.numpy.argpartition(a, kth, axis=-1)[source]#
返回部分排序数组的索引。
JAX 实现的
numpy.argpartition()
。JAX 版本在处理 NaN 项时与 NumPy 不同:设置了负位的 NaN 将排序到数组的开头。- 参数::
- 返回::
沿
axis
轴,对a
进行分区,返回分区索引,其中kth
为分区位置。kth
之前的索引对应于小于take(a, kth, axis)
的值,kth
之后的索引对应于大于take(a, kth, axis)
的值。- 返回类型:
注意
JAX 版本要求
kth
参数为静态整数,而非通用数组。 此功能通过两次调用jax.lax.top_k()
来实现。 如果您只需要访问输出的顶部或底部 k 个值,则直接调用jax.lax.top_k()
可能效率更高。另请参见
jax.numpy.partition()
: 直接部分排序jax.numpy.argsort()
: 完全间接排序jax.lax.top_k()
: 直接查找顶部 k 个条目jax.lax.approx_max_k()
: 计算近似顶部 k 个条目jax.lax.approx_min_k()
: 计算近似底部 k 个条目
示例
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3]) >>> kth = 4 >>> idx = jnp.argpartition(x, kth) >>> idx Array([4, 8, 3, 9, 2, 0, 1, 5, 6, 7], dtype=int32)
结果是索引序列,对输入进行部分排序。 所有
kth
之前的索引对应于小于枢纽值的值,所有kth
之后的索引对应于大于枢纽值的值>>> x_partitioned = x[idx] >>> smallest_values = x_partitioned[:kth] >>> pivot_value = x_partitioned[kth] >>> largest_values = x_partitioned[kth + 1:] >>> print(smallest_values, pivot_value, largest_values) [1 2 3 3] 4 [6 8 9 7 5]
请注意,在
smallest_values
和largest_values
中,返回的顺序是任意的,并且取决于实现。