jax.lax.top_k

内容

jax.lax.top_k#

jax.lax.top_k(operand, k)[source]#

返回 operand 最后一个轴上的前 k 个值及其索引。

参数:
  • operand (ArrayLike) – 非复数类型的 N 维数组。

  • k (int) – 指定前 n 个条目的整数。

返回值:

一个元组 (values, indices),其中

  • values 是一个数组,包含最后一个轴上的前 k 个值。

  • indices 是一个数组,包含对应于值的索引。

返回类型:

tuple[Array, Array]

示例

在数组中找到最大的三个值及其索引

>>> x = jnp.array([9., 3., 6., 4., 10.])
>>> values, indices = jax.lax.top_k(x, 3)
>>> values
Array([10.,  9.,  6.], dtype=float32)
>>> indices
Array([4, 0, 2], dtype=int32)