jax.lax.top_k#

jax.lax.top_k(operand, k)[来源]#

返回沿 operand 的最后一个轴的前 k 个值及其索引。

参数
  • operand (ArrayLike) – 非复数类型的 N 维数组。

  • k (int) – 指定顶部条目数量的整数。

返回

一个元组 (values, indices),其中

  • values 是一个数组,包含沿最后一个轴的最大的 k 个值。

  • indices 是一个数组,包含与值相对应的索引。

返回类型:

tuple[Array, Array]

示例

查找数组中最大的三个值及其索引

>>> x = jnp.array([9., 3., 6., 4., 10.])
>>> values, indices = jax.lax.top_k(x, 3)
>>> values
Array([10.,  9.,  6.], dtype=float32)
>>> indices
Array([4, 0, 2], dtype=int32)