jax.lax.approx_min_k

jax.lax.approx_min_k#

jax.lax.approx_min_k(operand, k, reduction_dimension=-1, recall_target=0.95, reduction_input_size_override=-1, aggregate_to_topk=True)[source]#

以近似方式返回 operand 中最小 k 个值及其索引。

有关算法详细信息,请参阅 https://arxiv.org/abs/2206.14286

参数:
  • operand (Array) – 要搜索最小 k 的数组。必须是浮点数类型。

  • k (int) – 指定最小 k 的数量。

  • reduction_dimension (int) – 用于搜索的整数维度。默认值为 -1。

  • recall_target (float) – 近似值的召回目标。

  • reduction_input_size_override (int) – 当设置为正值时,它会覆盖由 operand[reduction_dim] 确定的大小,用于评估召回率。此选项在给定操作数仅是 SPMD 或分布式管道中整体计算的一个子集时很有用,在这种情况下,真实输入大小不能由 operand 形状推迟。

  • aggregate_to_topk (bool) – 当为 True 时,会将近似结果聚合到排序后的前 k 个。当为 False 时,返回未排序的近似结果。在这种情况下,近似结果的数量由实现定义,并且大于或等于指定的 k

返回值:

两个数组的元组。数组是输入 operand 沿 reduction_dimension 的最小的 k 个值和对应的索引。数组的维度与输入 operand 相同,除了 reduction_dimension:当 aggregate_to_topk 为 True 时,缩减维度为 k;否则,它大于等于 k,其中大小由实现定义。

返回类型:

tuple[Array, Array]

我们鼓励用户使用 jit 包装 approx_min_k。请参阅以下关于平方 l2 距离的最近邻搜索的示例

>>> import functools
>>> import jax
>>> import numpy as np
>>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
... def l2_ann(qy, db, half_db_norms, k=10, recall_target=0.95):
...   dists = half_db_norms - jax.lax.dot(qy, db.transpose())
...   return jax.lax.approx_min_k(dists, k=k, recall_target=recall_target)
>>>
>>> qy = jax.numpy.array(np.random.rand(50, 64))
>>> db = jax.numpy.array(np.random.rand(1024, 64))
>>> half_db_norm_sq = jax.numpy.linalg.norm(db, axis=1)**2 / 2
>>> dists, neighbors = l2_ann(qy, db, half_db_norm_sq, k=10)

在上面的示例中,我们计算了 db^2/2 - dot(qy, db^T) 而不是 qy^2 - 2 dot(qy, db^T) + db^2,以提高性能。前者使用的算术运算更少,并且产生相同的邻居集。