jax.numpy.searchsorted#
- jax.numpy.searchsorted(a, v, side='left', sorter=None, *, method='scan')[source]#
在已排序数组中执行二分查找。
JAX 对
numpy.searchsorted()
的实现。这将返回已排序数组
a
中可以插入v
中的值以保持其排序顺序的索引。- 参数:
a (ArrayLike) – 一维数组,除非指定了
sorter
,否则假定为已排序的顺序。v (ArrayLike) – N 维查询值数组
side (str) –
'left'
(默认) 或'right'
;指定在出现并列的情况下,插入索引是在左侧还是右侧。sorter (ArrayLike | None) – 可选的索引数组,指定
a
的排序顺序。如果指定,则算法假定a[sorter]
是已排序的。method (str) –
'scan'
(默认)、'scan_unrolled'
、'sort'
或'compare_all'
之一。请参阅下面的注意。
- 返回:
形状为
v.shape
的插入索引数组。- 返回类型:
注意
method
参数控制用于计算插入索引的算法。'scan'
(默认)在 CPU 上往往性能更高,尤其是在a
非常大时。'scan_unrolled'
在 GPU 上性能更高,但会增加编译时间。'sort'
在像 GPU 和 TPU 这样的加速器后端上通常性能更高,尤其是在v
非常大时。当
a
非常小时,'compare_all'
往往性能最高。
示例
搜索单个值
>>> a = jnp.array([1, 2, 2, 3, 4, 5, 5]) >>> jnp.searchsorted(a, 2) Array(1, dtype=int32) >>> jnp.searchsorted(a, 2, side='right') Array(3, dtype=int32)
搜索一批值
>>> vals = jnp.array([0, 3, 8, 1.5, 2]) >>> jnp.searchsorted(a, vals) Array([0, 3, 7, 1, 1], dtype=int32)
可选地,可以使用
sorter
参数来查找通过jax.numpy.argsort()
排序的数组的插入索引>>> a = jnp.array([4, 3, 5, 1, 2]) >>> sorter = jnp.argsort(a) >>> jnp.searchsorted(a, vals, sorter=sorter) Array([0, 2, 5, 1, 1], dtype=int32)
结果等同于传递已排序的数组
>>> jnp.searchsorted(jnp.sort(a), vals) Array([0, 2, 5, 1, 1], dtype=int32)