jax.numpy.histogram#

jax.numpy.histogram(a, bins=10, range=None, weights=None, density=None)[源代码]#

计算一维直方图。

numpy.histogram() 的 JAX 实现。

参数:
  • a (类数组) – 要分箱的值数组。可以是任意大小或维度。

  • bins (类数组) – 指定直方图中的箱数 (默认:10)。 bins 也可以是一个数组,指定箱边缘的位置。

  • range (Sequence[ArrayLike] | None | None) – 标量元组。指定数据的范围。如果未指定,则从数据推断范围。

  • weights (类数组 | None | None) – 一个可选的数组,用于指定数据点的权重。应与 a 兼容广播。如果未指定,则每个数据点的权重相同。

  • density (bool | None | None) – 如果为 True,则返回单位长度内计数归一化的直方图。如果为 False (默认),则返回每个 bin 的(加权)计数。

返回值:

一个数组元组 (histogram, bin_edges),其中 histogram 包含聚合的数据,而 bin_edges 指定 bin 的边界。

返回类型:

tuple[Array, Array]

另请参阅

示例

>>> a = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])
>>> counts, bin_edges = jnp.histogram(a, bins=8)
>>> print(counts)
[3. 0. 0. 2. 1. 0. 1. 1.]
>>> print(bin_edges)
[ 1.  4.  7. 10. 13. 16. 19. 22. 25.]

指定 bin 的范围

>>> counts, bin_edges = jnp.histogram(a, range=(0, 25), bins=5)
>>> print(counts)
[3. 0. 2. 2. 1.]
>>> print(bin_edges)
[ 0.  5. 10. 15. 20. 25.]

显式指定 bin 的边界

>>> bin_edges = jnp.array([0, 10, 20, 30])
>>> counts, _ = jnp.histogram(a, bins=bin_edges)
>>> print(counts)
[3. 4. 1.]

使用 density=True 返回归一化的直方图

>>> density, bin_edges = jnp.histogram(a, density=True)
>>> dx = jnp.diff(bin_edges)
>>> normed_sum = jnp.sum(density * dx)
>>> jnp.allclose(normed_sum, 1.0)
Array(True, dtype=bool)