jax.Array.take

内容

jax.Array.take#

abstract Array.take(indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)[source]#

从数组中获取元素。

请参考 jax.numpy.take() 获取完整文档。

参数:
  • self (Array)

  • indices (ArrayLike)

  • axis (int | None)

  • out (None)

  • mode (str | None)

  • unique_indices (bool)

  • indices_are_sorted (bool)

  • fill_value (StaticScalar | None)

返回值类型:

Array