jax.numpy.histogram#
- jax.numpy.histogram(a, bins=10, range=None, weights=None, density=None)[源代码]#
计算一维直方图。
numpy.histogram()
的 JAX 实现。- 参数:
a (类数组) – 要分箱的值数组。可以是任意大小或维度。
bins (类数组) – 指定直方图中的箱数 (默认:10)。
bins
也可以是一个数组,指定箱边缘的位置。range (Sequence[ArrayLike] | None | None) – 标量元组。指定数据的范围。如果未指定,则从数据推断范围。
weights (类数组 | None | None) – 一个可选的数组,用于指定数据点的权重。应与
a
兼容广播。如果未指定,则每个数据点的权重相同。density (bool | None | None) – 如果为 True,则返回单位长度内计数归一化的直方图。如果为 False (默认),则返回每个 bin 的(加权)计数。
- 返回值:
一个数组元组
(histogram, bin_edges)
,其中histogram
包含聚合的数据,而bin_edges
指定 bin 的边界。- 返回类型:
另请参阅
jax.numpy.bincount()
: 计算数组中每个值的出现次数。jax.numpy.histogram2d()
: 计算二维数组的直方图。jax.numpy.histogramdd()
: 计算 N 维数组的直方图。jax.numpy.histogram_bin_edges()
: 计算直方图的 bin 边界。
示例
>>> a = jnp.array([1, 2, 3, 10, 11, 15, 19, 25]) >>> counts, bin_edges = jnp.histogram(a, bins=8) >>> print(counts) [3. 0. 0. 2. 1. 0. 1. 1.] >>> print(bin_edges) [ 1. 4. 7. 10. 13. 16. 19. 22. 25.]
指定 bin 的范围
>>> counts, bin_edges = jnp.histogram(a, range=(0, 25), bins=5) >>> print(counts) [3. 0. 2. 2. 1.] >>> print(bin_edges) [ 0. 5. 10. 15. 20. 25.]
显式指定 bin 的边界
>>> bin_edges = jnp.array([0, 10, 20, 30]) >>> counts, _ = jnp.histogram(a, bins=bin_edges) >>> print(counts) [3. 4. 1.]
使用
density=True
返回归一化的直方图>>> density, bin_edges = jnp.histogram(a, density=True) >>> dx = jnp.diff(bin_edges) >>> normed_sum = jnp.sum(density * dx) >>> jnp.allclose(normed_sum, 1.0) Array(True, dtype=bool)