jax.numpy.average#
- jax.numpy.average(a, axis=None, weights=None, returned=False, keepdims=False)[源代码]#
计算加权平均值。
numpy.average()
的 JAX 实现。- 参数:
- 返回:
一个数组
average
或数组元组(average, normalization)
,如果returned
为 True。- 返回类型:
另请参阅
jax.numpy.mean()
:未加权的平均值。
示例
简单平均值
>>> x = jnp.array([1, 2, 3, 2, 4]) >>> jnp.average(x) Array(2.4, dtype=float32)
加权平均值
>>> weights = jnp.array([2, 1, 3, 2, 2]) >>> jnp.average(x, weights=weights) Array(2.5, dtype=float32)
使用
returned=True
可选择返回归一化因子,即权重之和>>> jnp.average(x, returned=True) (Array(2.4, dtype=float32), Array(5., dtype=float32)) >>> jnp.average(x, weights=weights, returned=True) (Array(2.5, dtype=float32), Array(10., dtype=float32))
沿指定轴的加权平均值
>>> x = jnp.array([[8, 2, 7], ... [3, 6, 4]]) >>> weights = jnp.array([1, 2, 3]) >>> jnp.average(x, weights=weights, axis=1) Array([5.5, 4.5], dtype=float32)