jax.lax.gather#
- jax.lax.gather(operand, start_indices, dimension_numbers, slice_sizes, *, unique_indices=False, indices_are_sorted=False, mode=None, fill_value=None)[源代码]#
收集运算符。
包装 XLA 的收集运算符。
gather()
是一个具有复杂语义的底层运算符,大多数 JAX 用户永远不需要直接调用它。相反,您应该更喜欢使用 Numpy 风格的索引和/或jax.numpy.ndarray.at()
,也许可以与jax.vmap()
结合使用。- 参数:
operand (类数组) – 从中提取切片的数组。
start_indices (类数组) – 提取切片的起始索引。
dimension_numbers (GatherDimensionNumbers) – 一个 lax.GatherDimensionNumbers 对象,描述 operand、start_indices 和输出的维度如何关联。
slice_sizes (Shape) – 每个切片的大小。必须是一个非负整数序列,其长度等于 ndim(operand)。
indices_are_sorted (bool) – 是否已知 indices 是已排序的。如果为 true,可能会提高某些后端上的性能。
unique_indices (bool) – 从
operand
中收集的元素是否保证不会彼此重叠。如果为True
,这可能会提高某些后端上的性能。JAX 不会检查此承诺:如果元素重叠,则行为是未定义的。mode (str | GatherScatterMode | None | None) – 如何处理超出范围的索引:当设置为
'clip'
时,索引会被裁剪,以便切片在边界内;当设置为'fill'
或'drop'
时,gather 会为受影响的切片返回一个充满fill_value
的切片。当设置为'promise_in_bounds'
时,超出范围的索引的行为是实现定义的。fill_value – 当 mode 为
'fill'
时,返回超出范围切片的填充值。否则将被忽略。对于非精确类型,默认为NaN
;对于有符号类型,默认为最大负值;对于无符号类型,默认为最大正值;对于布尔值,默认为True
。
- 返回:
一个包含 gather 输出的数组。
- 返回类型:
示例
如上所述,您基本上永远不应直接使用
gather()
,而是使用 NumPy 风格的索引表达式从数组中收集值。例如,下面是如何使用直接索引语义在特定索引处提取值,它会降低到 XLA 的 Gather 运算符
>>> import jax.numpy as jnp >>> x = jnp.array([10, 11, 12]) >>> indices = jnp.array([0, 1, 1, 2, 2, 2])
>>> x[indices] Array([10, 11, 11, 12, 12, 12], dtype=int32)
要控制诸如
indices_are_sorted
、unique_indices
、mode
和fill_value
之类的设置,可以使用jax.numpy.ndarray.at
语法>>> x.at[indices].get(indices_are_sorted=True, mode="promise_in_bounds") Array([10, 11, 11, 12, 12, 12], dtype=int32)
相比之下,这里是直接使用
gather()
的等效函数调用,这不是典型用户需要执行的操作>>> from jax import lax >>> lax.gather(x, indices[:, None], slice_sizes=(1,), ... dimension_numbers=lax.GatherDimensionNumbers( ... offset_dims=(), ... collapsed_slice_dims=(0,), ... start_index_map=(0,)), ... indices_are_sorted=True, ... mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS) Array([10, 11, 11, 12, 12, 12], dtype=int32)