jax.ops.segment_sum

内容

jax.ops.segment_sum#

jax.ops.segment_sum(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False, bucket_size=None, mode=None)[source]#

计算数组中段内的总和。

类似于 TensorFlow 的 segment_sum

参数:
  • data (ArrayLike) – 要求和的值的数组。

  • segment_ids (ArrayLike) – 整数类型的数组,指示要求和的 data(沿其前导轴)的段。值可以重复,不需要排序。

  • num_segments (int | None | None) – 可选,一个非负整数值,表示分段的数量。默认设置为支持segment_ids中所有索引的最小分段数,计算方式为max(segment_ids) + 1。由于num_segments决定了输出的大小,因此必须提供一个静态值才能在 JIT 编译的函数中使用segment_sum

  • indices_are_sorted (bool) – segment_ids是否已知为排序。

  • unique_indices (bool) – segment_ids是否已知不包含重复项。

  • bucket_size (int | None | None) – 用于将索引分组的桶的大小。segment_sum分别对每个桶执行,以提高加法的数值稳定性。默认值None表示不进行分桶。

  • mode (lax.GatherScatterMode | None | None) – 一个jax.lax.GatherScatterMode值,描述如何处理越界索引。默认情况下,超出范围 [0, num_segments) 的值将被丢弃,并且不会对总和做出贡献。

返回:

一个形状为(num_segments,) + data.shape[1:]的数组,表示分段的总和。

返回类型:

数组

示例

简单的 1D 分段求和

>>> data = jnp.arange(5)
>>> segment_ids = jnp.array([0, 0, 1, 1, 2])
>>> segment_sum(data, segment_ids)
Array([1, 5, 4], dtype=int32)

使用 JIT 需要静态的num_segments

>>> from jax import jit
>>> jit(segment_sum, static_argnums=2)(data, segment_ids, 3)
Array([1, 5, 4], dtype=int32)