jax.Array.take# abstract Array.take(indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)[source]# 从数组中获取元素。 请参考 jax.numpy.take() 获取完整文档。 参数: self (Array) indices (ArrayLike) axis (int | None) out (None) mode (str | None) unique_indices (bool) indices_are_sorted (bool) fill_value (StaticScalar | None) 返回值类型: Array