jax.numpy.searchsorted#

jax.numpy.searchsorted(a, v, side='left', sorter=None, *, method='scan')[source]#

在已排序的数组中执行二分查找。

numpy.searchsorted() 的 JAX 实现。

这将返回已排序数组 a 中, 值 v 可以插入以保持其排序顺序的索引。

参数:
  • a (ArrayLike) – 一维数组,假设已排序,除非指定了 sorter

  • v (ArrayLike) – N 维查询值数组

  • side (str) – 'left' (默认) 或 'right';指定在存在并列的情况下,插入索引是在左侧还是右侧。

  • sorter (ArrayLike | None) – 可选的索引数组,指定 a 的排序顺序。如果指定,则算法假设 a[sorter] 已排序。

  • method (str) – 'scan' (默认), 'scan_unrolled', 'sort''compare_all' 之一。请参阅下面的 *注意*。

返回:

形状为 v.shape 的插入索引数组。

返回类型:

数组

注意

method 参数控制用于计算插入索引的算法。

  • 'scan' (默认)在 CPU 上往往性能更高,尤其是在 a 非常大的情况下。

  • 'scan_unrolled' 在 GPU 上性能更高,但会增加编译时间。

  • 'sort' 通常在 GPU 和 TPU 等加速器后端上性能更高,尤其是在 v 非常大的情况下。

  • a 非常小时,'compare_all' 往往性能最高。

示例

搜索单个值

>>> a = jnp.array([1, 2, 2, 3, 4, 5, 5])
>>> jnp.searchsorted(a, 2)
Array(1, dtype=int32)
>>> jnp.searchsorted(a, 2, side='right')
Array(3, dtype=int32)

搜索一批值

>>> vals = jnp.array([0, 3, 8, 1.5, 2])
>>> jnp.searchsorted(a, vals)
Array([0, 3, 7, 1, 1], dtype=int32)

(可选)可以使用 sorter 参数查找通过 jax.numpy.argsort() 排序的数组的插入索引

>>> a = jnp.array([4, 3, 5, 1, 2])
>>> sorter = jnp.argsort(a)
>>> jnp.searchsorted(a, vals, sorter=sorter)
Array([0, 2, 5, 1, 1], dtype=int32)

结果等效于传递已排序的数组

>>> jnp.searchsorted(jnp.sort(a), vals)
Array([0, 2, 5, 1, 1], dtype=int32)