jax.lax.top_k#
- jax.lax.top_k(operand, k)[来源]#
返回沿
operand
的最后一个轴的前k
个值及其索引。- 参数:
operand (ArrayLike) – 非复数类型的 N 维数组。
k (int) – 指定顶部条目数量的整数。
- 返回:
一个元组
(values, indices)
,其中values
是一个数组,包含沿最后一个轴的最大的 k 个值。indices
是一个数组,包含与值相对应的索引。
- 返回类型:
示例
查找数组中最大的三个值及其索引
>>> x = jnp.array([9., 3., 6., 4., 10.]) >>> values, indices = jax.lax.top_k(x, 3) >>> values Array([10., 9., 6.], dtype=float32) >>> indices Array([4, 0, 2], dtype=int32)