jax.lax.sort_key_val

jax.lax.sort_key_val#

jax.lax.sort_key_val(keys, values, dimension=-1, is_stable=True)[source]#

沿着 dimensionkeys 进行排序,并将相同的排列应用于 values

参数:
  • keys (Array)

  • values (ArrayLike)

  • dimension (int)

  • is_stable (bool)

返回类型:

tuple[Array, Array]