jax.lax.gather#
- jax.lax.gather(operand, start_indices, dimension_numbers, slice_sizes, *, unique_indices=False, indices_are_sorted=False, mode=None, fill_value=None)[source]#
gather 运算符。
包装了 XLA 的 Gather 运算符.
gather 的语义很复杂,它的 API 可能会在将来发生改变。对于大多数用例,您应该优先使用 Numpy 风格的索引(例如,x[:, (1,4,7), …]),而不是直接使用 gather。
- 参数:
operand (ArrayLike) – 应该从中获取切片的数组
start_indices (ArrayLike) – 应该从中获取切片的索引
dimension_numbers (GatherDimensionNumbers) – 一个 lax.GatherDimensionNumbers 对象,描述了 operand、start_indices 和输出的维度之间的关系。
slice_sizes (Shape) – 每个切片的大小。必须是一个非负整数序列,长度等于 ndim(operand)。
indices_are_sorted (bool) – 是否知道 indices 已排序。如果为真,可能会在某些后端上提高性能。
unique_indices (bool) – 从
operand
中收集的元素是否保证不会相互重叠。如果为True
,这可能会提高某些后端的性能。JAX 不会检查此承诺:如果元素重叠,则行为未定义。mode (str | GatherScatterMode | None | None) – 如何处理越界索引:当设置为
'clip'
时,索引会被截断以使切片在范围内,当设置为'fill'
或'drop'
时,gather 会返回一个充满fill_value
的切片。当设置为'promise_in_bounds'
时,越界索引的行为是实现定义的。fill_value – 当 mode 为
'fill'
时,用于返回越界切片的填充值。否则会被忽略。默认情况下,对于不精确类型为NaN
,对于有符号类型为最大负值,对于无符号类型为最大正值,对于布尔值为True
。
- 返回:
包含 gather 输出的数组。
- 返回类型: