jax.lax.approx_min_k#
- jax.lax.approx_min_k(operand, k, reduction_dimension=-1, recall_target=0.95, reduction_input_size_override=-1, aggregate_to_topk=True)[源代码]#
以近似方式返回
operand
的最小k
值及其索引。有关算法详细信息,请参阅 https://arxiv.org/abs/2206.14286。
- 参数:
operand (Array) – 用于搜索最小 k 值的数组。必须是浮点数类型。
k (int) – 指定最小 k 值的数量。
reduction_dimension (int) – 搜索的整数维度。默认值:-1。
recall_target (float) – 近似值的召回目标。
reduction_input_size_override (int) – 当设置为正值时,它会覆盖由
operand[reduction_dim]
确定的用于评估召回率的大小。当给定的操作数只是 SPMD 或分布式管道中整体计算的一个子集时,此选项很有用,在这些情况下,真实的输入大小不能通过operand
形状来推迟确定。aggregate_to_topk (bool) – 当为 true 时,将近似结果聚合到排序后的前 k 个。当为 false 时,返回未排序的近似结果。在这种情况下,近似结果的数量是实现定义的,并且大于或等于指定的
k
。
- 返回:
包含两个数组的元组。这些数组是输入
operand
沿reduction_dimension
的最小k
个值和相应的索引。这些数组的维度与输入operand
相同,除了reduction_dimension
:当aggregate_to_topk
为 true 时,缩减维度是k
;否则,它大于等于k
,其中大小是实现定义的。- 返回类型:
我们鼓励用户使用 jit 包装
approx_min_k
。 请参阅以下示例,了解如何使用平方 l2 距离进行最近邻搜索>>> import functools >>> import jax >>> import numpy as np >>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"]) ... def l2_ann(qy, db, half_db_norms, k=10, recall_target=0.95): ... dists = half_db_norms - jax.lax.dot(qy, db.transpose()) ... return jax.lax.approx_min_k(dists, k=k, recall_target=recall_target) >>> >>> qy = jax.numpy.array(np.random.rand(50, 64)) >>> db = jax.numpy.array(np.random.rand(1024, 64)) >>> half_db_norm_sq = jax.numpy.linalg.norm(db, axis=1)**2 / 2 >>> dists, neighbors = l2_ann(qy, db, half_db_norm_sq, k=10)
在上面的示例中,出于性能原因,我们计算
db^2/2 - dot(qy, db^T)
而不是qy^2 - 2 dot(qy, db^T) + db^2
。前者使用较少的算术运算,并产生相同的邻居集合。