jax.numpy.searchsorted

jax.numpy.searchsorted#

jax.numpy.searchsorted(a, v, side='left', sorter=None, *, method='scan')[source]#

在已排序的数组中执行二分查找。

JAX 实现 numpy.searchsorted()

这将返回已排序数组 a 中的索引,在这些索引处可以插入 v 中的值以保持其排序顺序。

参数:
  • a (ArrayLike) – 一维数组,除非指定了 sorter,否则假定为已排序。

  • v (ArrayLike) – 查询值的 N 维数组

  • side (str) – 'left'(默认)或'right';在出现平局的情况下,指定插入索引是在左侧还是右侧。

  • sorter (ArrayLike | None) – a 的排序顺序的可选索引数组。如果指定,则算法假定a[sorter] 已排序。

  • method (str) – 'scan'(默认)、'scan_unrolled''sort''compare_all' 之一。请参见下面的“注意”。

返回:

形状为v.shape 的插入索引数组。

返回类型:

数组

注意

method 参数控制用于计算插入索引的算法。

  • 'scan'(默认)在 CPU 上的性能往往更好,尤其是在a 非常大的情况下。

  • 'scan_unrolled' 在 GPU 上的性能更好,但需要额外的编译时间。

  • 'sort' 在 GPU 和 TPU 等加速器后端上的性能通常更好,尤其是在v 非常大的情况下。

  • 'compare_all'a 非常小的情况下,性能往往最好。

示例

搜索单个值

>>> a = jnp.array([1, 2, 2, 3, 4, 5, 5])
>>> jnp.searchsorted(a, 2)
Array(1, dtype=int32)
>>> jnp.searchsorted(a, 2, side='right')
Array(3, dtype=int32)

搜索一批值

>>> vals = jnp.array([0, 3, 8, 1.5, 2])
>>> jnp.searchsorted(a, vals)
Array([0, 3, 7, 1, 1], dtype=int32)

可选地,可以使用sorter 参数查找通过jax.numpy.argsort() 排序的数组中的插入索引。

>>> a = jnp.array([4, 3, 5, 1, 2])
>>> sorter = jnp.argsort(a)
>>> jnp.searchsorted(a, vals, sorter=sorter)
Array([0, 2, 5, 1, 1], dtype=int32)

结果等同于传递排序后的数组

>>> jnp.searchsorted(jnp.sort(a), vals)
Array([0, 2, 5, 1, 1], dtype=int32)