jax.Array.take

abstract Array.take(indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)

Take elements from an array.

Refer to jax.numpy.take() for full documentation.
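The following is a minimal sketch of the method form in use. The array `x` and the index values are illustrative only. It shows that with the default axis=None the array is flattened in row-major order before indexing, while an explicit axis selects along that dimension. The method call gives the same result as the jax.numpy.take() function it forwards to.

```python
import jax.numpy as jnp

# Illustrative 2x3 input array.
x = jnp.array([[10, 20, 30],
               [40, 50, 60]])

# axis=None (the default): x is treated as the flattened array
# [10, 20, 30, 40, 50, 60], so index 4 selects 50.
flat = x.take(jnp.array([0, 4]))

# axis=1: select columns 2 and 0 from every row.
cols = x.take(jnp.array([2, 0]), axis=1)

# The function form is equivalent to the method form.
same = jnp.take(x, jnp.array([2, 0]), axis=1)

print(flat)  # [10 50]
print(cols)  # [[30 10]
             #  [60 40]]
```

Indices are passed as a JAX array rather than a Python list because jax.numpy functions generally reject plain lists as array arguments.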

Parameters:
  • self (Array)

  • indices (ArrayLike)

  • axis (int | None)

  • out (None)

  • mode (str | None)

  • unique_indices (bool)

  • indices_are_sorted (bool)

  • fill_value (StaticScalar | None)

Return type:

Array
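The mode and fill_value parameters control how out-of-bounds indices are handled. The sketch below assumes a recent JAX release in which mode=None behaves like mode="fill" and the default fill value for floating-point arrays is NaN. Consult jax.numpy.take() for the authoritative semantics in your installed version. The input array and indices are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
idx = jnp.array([0, 4])  # index 4 is out of bounds for an axis of length 3

# mode=None (assumed to act as "fill"): out-of-bounds positions receive
# the fill value, which defaults to NaN for floating-point arrays.
filled_nan = x.take(idx)

# An explicit fill_value replaces the default NaN.
filled_custom = x.take(idx, fill_value=-1.0)

# mode="clip": out-of-bounds indices are clamped to the valid range,
# so index 4 reads the last element.
clipped = x.take(idx, mode="clip")

print(filled_nan)     # [ 1. nan]
print(filled_custom)  # [ 1. -1.]
print(clipped)        # [1. 3.]
```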