jax.lax.sort_key_val# jax.lax.sort_key_val(keys, values, dimension=-1, is_stable=True)[source]# 沿着 dimension 对 keys 进行排序,并将相同的排列应用于 values。 参数: keys (Array) values (ArrayLike) dimension (int) is_stable (bool) 返回类型: tuple[Array, Array]