jax.lax.index_take

jax.lax.index_take(src, idxs, axes)

Parameters:
  src (Array)
  idxs (Array)
  axes (Sequence[int])

Return type: Array
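The entry lists only the signature. The sketch below is a hedged illustration. It assumes the behavior found in the JAX source at the time of writing, not behavior stated on this page:
- `idxs` holds one 1-D integer index array per entry of `axes`. Here the arrays are stacked into shape `(len(axes), N)`.
- Each index is wrapped modulo the size of its axis.
- The result has shape `(N, *remaining_dims)`, and the un-indexed dimensions of `src` keep their original order.

Treat these details as assumptions and check them against your installed JAX version. The `ref_*` arrays are ordinary NumPy-style indexing, included only for comparison.

```python
import jax
import jax.numpy as jnp

# Source array of shape (3, 4):
# [[0, 1,  2,  3],
#  [4, 5,  6,  7],
#  [8, 9, 10, 11]]
src = jnp.arange(12).reshape(3, 4)

# One 1-D index array per axis in `axes`, all of the same length N = 3.
# -1 and 5 are out of range; under the assumed modulo wrapping they
# become 2 (axis 0 has size 3) and 1 (axis 1 has size 4).
rows = jnp.array([0, 2, -1])
cols = jnp.array([1, 3, 5])

# Index along both axes: result[n] = src[rows[n] % 3, cols[n] % 4].
both = jax.lax.index_take(src, jnp.stack([rows, cols]), axes=(0, 1))

# Index along axis 0 only: result[n, :] = src[rows[n] % 3, :].
by_row = jax.lax.index_take(src, rows[None, :], axes=(0,))

# Index along axis 1 only: result[n, j] = src[j, cols[n] % 4], so the
# indexed position comes first and the result has shape (3, 3).
by_col = jax.lax.index_take(src, cols[None, :], axes=(1,))

# NumPy-style equivalents under the same assumptions, for comparison.
ref_both = src[rows % 3, cols % 4]
ref_by_row = src[rows % 3]
ref_by_col = src[:, cols % 4].T
```

In this reading, `index_take` with all axes listed behaves like advanced indexing with wrapped indices. Listing a subset of axes gathers whole slices along the remaining dimensions.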