jax.numpy.argsort

内容

jax.numpy.argsort#

jax.numpy.argsort(a, axis=-1, *, kind=None, order=None, stable=True, descending=False)[source]#

返回对数组进行排序的索引。

JAX 实现 numpy.argsort().

参数:
  • a (ArrayLike) – 要排序的数组

  • axis (int | None) – 整数轴,沿着该轴排序。默认值为 -1,即最后一个轴。如果为 None,则在排序之前将 a 展平。

  • stable (bool) – 布尔值,指定是否使用稳定排序。默认值为 True。

  • descending (bool) – 布尔值,指定是否以降序排序。默认值为 False。

  • kind (None) – 已弃用;改为使用 stable=True 或 stable=False 指定排序算法。

  • order (None) – JAX 不支持。

返回:

排序数组的索引数组。返回的数组的形状将为 a.shape(如果 axis 是整数)或 (a.size,)(如果 axis 为 None)。

返回类型:

数组

示例

简单的 1 维排序

>>> x = jnp.array([1, 3, 5, 4, 2, 1])
>>> indices = jnp.argsort(x)
>>> indices
Array([0, 5, 4, 1, 3, 2], dtype=int32)
>>> x[indices]
Array([1, 1, 2, 3, 4, 5], dtype=int32)

沿着数组的最后一个轴排序

>>> x = jnp.array([[2, 1, 3],
...                [6, 4, 3]])
>>> indices = jnp.argsort(x, axis=1)
>>> indices
Array([[1, 0, 2],
       [2, 1, 0]], dtype=int32)
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[1, 2, 3],
       [3, 4, 6]], dtype=int32)

另请参见