jax.numpy.partition#
- jax.numpy.partition(a, kth, axis=-1)[source]#
返回数组的部分排序副本。
JAX 实现
numpy.partition()
。JAX 版本在 NaN 条目的处理方面与 NumPy 不同:具有负位设置的 NaN 将排序到数组的开头。- 参数::
- 返回::
在
axis
轴上,按照kth
位置的值对a
进行分区复制。kth
位置之前的元素值小于take(a, kth, axis)
,而kth
位置之后的元素值大于take(a, kth, axis)
。- 返回值类型:
注意
JAX 版本要求
kth
参数是一个静态整数,而不是一个通用的数组。这是通过两次调用jax.lax.top_k()
实现的。如果您只访问输出的前 k 个或后 k 个值,直接调用jax.lax.top_k()
可能更有效。参见
jax.numpy.sort()
: 全排序jax.numpy.argpartition()
: 间接部分排序jax.lax.top_k()
: 直接查找前 k 个条目jax.lax.approx_max_k()
: 计算近似前 k 个条目jax.lax.approx_min_k()
: 计算近似后 k 个条目
示例
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3]) >>> kth = 4 >>> x_partitioned = jnp.partition(x, kth) >>> x_partitioned Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32)
结果是对输入的局部排序副本。所有
kth
位置之前的元素值小于枢轴值,所有kth
位置之后的元素值大于枢轴值>>> smallest_values = x_partitioned[:kth] >>> pivot_value = x_partitioned[kth] >>> largest_values = x_partitioned[kth + 1:] >>> print(smallest_values, pivot_value, largest_values) [1 2 3 3] 4 [9 8 7 6 5]
注意,在
smallest_values
和largest_values
中,返回的顺序是任意的,并且取决于实现。