jax.ops.segment_min#
- jax.ops.segment_min(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False, bucket_size=None, mode=None)[source]#
计算数组中段内的最小值。
类似于 TensorFlow 的 segment_min
- 参数::
data (ArrayLike) – 要约简的值的数组。
segment_ids (ArrayLike) – 一个整数类型的数组,指示 data (沿其前导轴) 要约简的段。值可以重复,并且不需要排序。超出范围 [0, num_segments) 的值将被丢弃,不会影响结果。
num_segments (int | None | None) – 可选,一个非负整数,表示段数。默认情况下,该值被设置为支持所有索引的最小段数
segment_ids
,计算方式为max(segment_ids) + 1
。由于 num_segments 决定输出的大小,因此在 JIT 编译函数中使用segment_min
必须提供静态值。indices_are_sorted (bool) – 是否已知
segment_ids
已排序。unique_indices (bool) – 是否已知 segment_ids 不包含重复项。
bucket_size (int | None | None) – 将索引分组的桶大小。对每个桶分别执行
segment_min
。默认情况下,None
表示不进行分桶。mode (lax.GatherScatterMode | None | None) – 一个
jax.lax.GatherScatterMode
值,描述如何处理越界索引。默认情况下,范围外的值 [0, num_segments) 被丢弃,不会计入总和。
- 返回值:
形状为
(num_segments,) + data.shape[1:]
的数组,表示段最小值。- 返回类型:
示例
简单的 1D 段最小值
>>> data = jnp.arange(6) >>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2]) >>> segment_min(data, segment_ids) Array([0, 2, 4], dtype=int32)
使用 JIT 需要静态的 num_segments
>>> from jax import jit >>> jit(segment_min, static_argnums=2)(data, segment_ids, 3) Array([0, 2, 4], dtype=int32)