jax.numpy.sort#
- jax.numpy.sort(a, axis=-1, *, kind=None, order=None, stable=True, descending=False)[源代码]#
返回数组的排序副本。
JAX 对
numpy.sort()
的实现。- 参数:
- 返回:
形状为
a.shape
的已排序数组(如果axis
是整数),或者形状为(a.size,)
的已排序数组(如果axis
是 None)。- 返回类型:
示例
简单的一维排序
>>> x = jnp.array([1, 3, 5, 4, 2, 1]) >>> jnp.sort(x) Array([1, 1, 2, 3, 4, 5], dtype=int32)
沿数组的最后一个轴排序
>>> x = jnp.array([[2, 1, 3], ... [4, 3, 6]]) >>> jnp.sort(x, axis=1) Array([[1, 2, 3], [3, 4, 6]], dtype=int32)
另请参阅
jax.numpy.argsort()
:返回已排序值的索引。jax.numpy.lexsort()
:多个数组的字典式排序。jax.lax.sort()
:包装 XLA Sort 运算符的底层函数。