jax.Array.take# abstract Array.take(indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)[源代码]# 从数组中提取元素。 完整文档请参考 jax.numpy.take()。 参数: self (Array) indices (ArrayLike) axis (int | None) out (None) mode (str | None) unique_indices (bool) indices_are_sorted (bool) fill_value (StaticScalar | None) 返回类型: Array