jax.ops.segment_max

目录

jax.ops.segment_max#

jax.ops.segment_max(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False, bucket_size=None, mode=None)[source]#

计算数组中段内的最大值。

类似于 TensorFlow 的 segment_max

参数:
  • data (ArrayLike) – 要进行缩减的值所在的数组。

  • segment_ids (ArrayLike) – 一个整型 dtype 的数组,指示要缩减的 data(沿着其前导轴)的段。值可以重复,不需要排序。范围之外的值 [0, num_segments) 将被丢弃,不会对结果做出贡献。

  • num_segments (int | None | None) – 可选,一个非负整型值,指示段的数量。默认值为支持 segment_ids 中所有索引的最小段数,计算为 max(segment_ids) + 1。由于 num_segments 决定输出的大小,因此必须提供一个静态值才能在 JIT 编译函数中使用 segment_max

  • indices_are_sorted (bool) – 是否已知 segment_ids 已排序。

  • unique_indices (布尔值) – 是否 segment_ids 已知不包含重复项。

  • bucket_size (整数 | | ) – 用于将索引分组到桶中的桶大小。 segment_max 在每个桶上单独执行。 默认值 None 表示不进行分桶。

  • mode (lax.GatherScatterMode | | ) – 一个 jax.lax.GatherScatterMode 值,描述如何处理超出范围的索引。 默认情况下,超出范围 [0, num_segments) 的值会被丢弃,不会对求和结果产生贡献。

返回值:

一个形状为 (num_segments,) + data.shape[1:] 的数组,表示每个段的最大值。

返回类型:

数组

示例

简单的 1D 段最大值

>>> data = jnp.arange(6)
>>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
>>> segment_max(data, segment_ids)
Array([1, 3, 5], dtype=int32)

使用 JIT 需要静态的 num_segments

>>> from jax import jit
>>> jit(segment_max, static_argnums=2)(data, segment_ids, 3)
Array([1, 3, 5], dtype=int32)