jax.numpy.partition

内容

jax.numpy.partition#

jax.numpy.partition(a, kth, axis=-1)[source]#

返回数组的部分排序副本。

JAX 实现 numpy.partition()。JAX 版本在 NaN 条目的处理方面与 NumPy 不同:具有负位设置的 NaN 将排序到数组的开头。

参数::
  • a (ArrayLike) – 要分区的数组。

  • kth (int) – 关于要对数组进行分区的静态整数索引。

  • axis (int) – 要对数组进行分区的静态整数轴;默认值为 -1。

返回::

axis 轴上,按照 kth 位置的值对 a 进行分区复制。 kth 位置之前的元素值小于 take(a, kth, axis),而 kth 位置之后的元素值大于 take(a, kth, axis)

返回值类型:

数组

注意

JAX 版本要求 kth 参数是一个静态整数,而不是一个通用的数组。这是通过两次调用 jax.lax.top_k() 实现的。如果您只访问输出的前 k 个或后 k 个值,直接调用 jax.lax.top_k() 可能更有效。

参见

示例

>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> x_partitioned = jnp.partition(x, kth)
>>> x_partitioned
Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32)

结果是对输入的局部排序副本。所有 kth 位置之前的元素值小于枢轴值,所有 kth 位置之后的元素值大于枢轴值

>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [9 8 7 6 5]

注意,在 smallest_valueslargest_values 中,返回的顺序是任意的,并且取决于实现。