jax.Array.argpartition

jax.Array.argpartition#

abstract Array.argpartition(kth, axis=-1)[source]#

返回部分排序数组的索引。

有关完整文档,请参阅 jax.numpy.argpartition()

参数:
返回类型:

Array