jax.numpy.searchsorted#

jax.numpy.searchsorted(a, v, side='left', sorter=None, *, method='scan')[source]#

在已排序数组中执行二分查找。

JAX 对 numpy.searchsorted() 的实现。

这将返回已排序数组 a 中可以插入 v 中的值以保持其排序顺序的索引。

参数:
  • a (ArrayLike) – 一维数组,除非指定了 sorter,否则假定为已排序的顺序。

  • v (ArrayLike) – N 维查询值数组

  • side (str) – 'left' (默认) 或 'right';指定在出现并列的情况下,插入索引是在左侧还是右侧。

  • sorter (ArrayLike | None) – 可选的索引数组,指定 a 的排序顺序。如果指定,则算法假定 a[sorter] 是已排序的。

  • method (str) – 'scan' (默认)、'scan_unrolled''sort''compare_all' 之一。请参阅下面的注意

返回:

形状为 v.shape 的插入索引数组。

返回类型:

数组

注意

method 参数控制用于计算插入索引的算法。

  • 'scan'(默认)在 CPU 上往往性能更高,尤其是在 a 非常大时。

  • 'scan_unrolled' 在 GPU 上性能更高,但会增加编译时间。

  • 'sort' 在像 GPU 和 TPU 这样的加速器后端上通常性能更高,尤其是在 v 非常大时。

  • a 非常小时,'compare_all' 往往性能最高。

示例

搜索单个值

>>> a = jnp.array([1, 2, 2, 3, 4, 5, 5])
>>> jnp.searchsorted(a, 2)
Array(1, dtype=int32)
>>> jnp.searchsorted(a, 2, side='right')
Array(3, dtype=int32)

搜索一批值

>>> vals = jnp.array([0, 3, 8, 1.5, 2])
>>> jnp.searchsorted(a, vals)
Array([0, 3, 7, 1, 1], dtype=int32)

可选地,可以使用 sorter 参数来查找通过 jax.numpy.argsort() 排序的数组的插入索引

>>> a = jnp.array([4, 3, 5, 1, 2])
>>> sorter = jnp.argsort(a)
>>> jnp.searchsorted(a, vals, sorter=sorter)
Array([0, 2, 5, 1, 1], dtype=int32)

结果等同于传递已排序的数组

>>> jnp.searchsorted(jnp.sort(a), vals)
Array([0, 2, 5, 1, 1], dtype=int32)