jax.numpy.histogram#

jax.numpy.histogram(a, bins=10, range=None, weights=None, density=None)[源代码]#

计算一维直方图。

JAX 实现的 numpy.histogram()

参数:
  • a (ArrayLike) – 要分箱的值的数组。可以是任何大小或维度。

  • bins (ArrayLike) – 指定直方图中的箱子数量(默认值:10)。bins 也可以是一个数组,指定箱子边缘的位置。

  • range (Sequence[ArrayLike] | None | None) – 标量的元组。指定数据的范围。如果未指定,则从数据推断范围。

  • weights (ArrayLike | None | None) – 一个可选的数组,指定数据点的权重。应与 a 广播兼容。如果未指定,则每个数据点的权重相等。

  • density (bool | None | None) – 如果为 True,则返回单位长度的计数归一化直方图。如果为 False (默认),则返回每个箱子的(加权)计数。

返回值:

数组的元组 (histogram, bin_edges),其中 histogram 包含聚合数据,而 bin_edges 指定箱子的边界。

返回类型:

tuple[Array, Array]

另请参阅

示例

>>> a = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])
>>> counts, bin_edges = jnp.histogram(a, bins=8)
>>> print(counts)
[3. 0. 0. 2. 1. 0. 1. 1.]
>>> print(bin_edges)
[ 1.  4.  7. 10. 13. 16. 19. 22. 25.]

指定箱子范围

>>> counts, bin_edges = jnp.histogram(a, range=(0, 25), bins=5)
>>> print(counts)
[3. 0. 2. 2. 1.]
>>> print(bin_edges)
[ 0.  5. 10. 15. 20. 25.]

显式指定箱子边缘

>>> bin_edges = jnp.array([0, 10, 20, 30])
>>> counts, _ = jnp.histogram(a, bins=bin_edges)
>>> print(counts)
[3. 4. 1.]

使用 density=True 返回归一化直方图

>>> density, bin_edges = jnp.histogram(a, density=True)
>>> dx = jnp.diff(bin_edges)
>>> normed_sum = jnp.sum(density * dx)
>>> jnp.allclose(normed_sum, 1.0)
Array(True, dtype=bool)