jax.numpy.average#

jax.numpy.average(a, axis=None, weights=None, returned=False, keepdims=False)[源代码]#

计算加权平均值。

numpy.average() 的 JAX 实现。

参数:
  • a (类数组) – 要计算平均值的数组

  • axis (Axis | None) – 一个可选的整数或整数序列,指定计算平均值的轴。如果未指定,则沿所有轴计算平均值。

  • weights (类数组 | None | None) – 一个可选的加权平均值的权重数组。必须与 a 广播兼容。

  • returned (bool) – 如果为 False(默认),则仅返回平均值。如果为 True,则返回平均值和归一化因子(即权重之和)。

  • keepdims (bool) – 如果为 True,则在结果中保留大小为 1 的缩减轴。如果为 False(默认),则会挤出缩减轴。

返回:

一个数组 average 或数组元组 (average, normalization) ,如果 returned 为 True。

返回类型:

Array | tuple[Array, Array]

另请参阅

示例

简单平均值

>>> x = jnp.array([1, 2, 3, 2, 4])
>>> jnp.average(x)
Array(2.4, dtype=float32)

加权平均值

>>> weights = jnp.array([2, 1, 3, 2, 2])
>>> jnp.average(x, weights=weights)
Array(2.5, dtype=float32)

使用 returned=True 可选择返回归一化因子,即权重之和

>>> jnp.average(x, returned=True)
(Array(2.4, dtype=float32), Array(5., dtype=float32))
>>> jnp.average(x, weights=weights, returned=True)
(Array(2.5, dtype=float32), Array(10., dtype=float32))

沿指定轴的加权平均值

>>> x = jnp.array([[8, 2, 7],
...                [3, 6, 4]])
>>> weights = jnp.array([1, 2, 3])
>>> jnp.average(x, weights=weights, axis=1)
Array([5.5, 4.5], dtype=float32)