jax.ops 模块#

在 JAX 0.2.22 版本中已弃用的函数 jax.ops.index_update, jax.ops.index_add 等已被移除。请改用 JAX 数组的 jax.numpy.ndarray.at 属性。

段归约运算符#

segment_max(data, segment_ids[, ...])

计算数组各个段内的最大值。

segment_min(data, segment_ids[, ...])

计算数组各个段内的最小值。

segment_prod(data, segment_ids[, ...])

计算数组各个段内的乘积。

segment_sum(data, segment_ids[, ...])

计算数组各个段内的总和。