jax.numpy.lexsort#

jax.numpy.lexsort(keys, axis=-1)[源代码]#

按字典顺序对键序列进行排序。

numpy.lexsort() 的 JAX 实现。

参数:
  • keys (Array | np.ndarray | Sequence[ArrayLike]) – 要排序的数组序列;所有数组必须具有相同的形状。序列中的最后一个键用作主键。

  • axis (int) – 要排序的轴(默认值:-1)。

返回:

一个形状为 keys[0].shape 的整数数组,给出按字典顺序排序的条目的索引。

返回类型:

Array

另请参阅

示例

lexsort() 使用单个键等同于 argsort()

>>> key1 = jnp.array([4, 2, 3, 2, 5])
>>> jnp.lexsort([key1])
Array([1, 3, 2, 0, 4], dtype=int32)
>>> jnp.argsort(key1)
Array([1, 3, 2, 0, 4], dtype=int32)

使用多个键时,lexsort() 使用最后一个键作为主键

>>> key2 = jnp.array([2, 1, 1, 2, 2])
>>> jnp.lexsort([key1, key2])
Array([1, 2, 3, 0, 4], dtype=int32)

当打印排序后的键时,索引的含义会更加清晰

>>> indices = jnp.lexsort([key1, key2])
>>> print(f"{key1[indices]}\n{key2[indices]}")
[2 3 2 4 5]
[1 1 2 2 2]

请注意,key2 的元素按顺序排列,并且在重复值的序列中,相应的 `key1 元素按顺序排列。

对于多维输入,lexsort() 默认沿最后一个轴排序

>>> key1 = jnp.array([[2, 4, 2, 3],
...                   [3, 1, 2, 2]])
>>> key2 = jnp.array([[1, 2, 1, 3],
...                   [2, 1, 2, 1]])
>>> jnp.lexsort([key1, key2])
Array([[0, 2, 1, 3],
       [1, 3, 2, 0]], dtype=int32)

可以使用 axis 关键字选择不同的排序轴;这里我们沿着前导轴排序

>>> jnp.lexsort([key1, key2], axis=0)
Array([[0, 1, 0, 1],
       [1, 0, 1, 0]], dtype=int32)