jax.numpy.histogram2d#

jax.numpy.histogram2d(x, y, bins=10, range=None, weights=None, density=None)[源代码]#

计算二维直方图。

JAX 实现的 numpy.histogram2d()

参数:
  • x (ArrayLike) – 要分箱的点的 x 值的一维数组。

  • y (ArrayLike) – 要分箱的点的 y 值的一维数组。

  • bins (类数组 | 列表[类数组]) – 指定直方图中的 bin 数量(默认值:10)。bins 也可以是一个数组,指定 bin 边缘的位置,或一对整数或一对数组,指定每个维度中的 bin 数量。

  • range (序列[None | Array | 序列[类数组]] | None | None) – 形式为 [[xmin, xmax], [ymin, ymax]] 的数组或列表对,指定每个维度中数据的范围。如果未指定,则从数据推断范围。

  • weights (类数组 | None | None) – 一个可选数组,指定数据点的权重。应与 xy 的形状相同。如果未指定,则每个数据点的权重相等。

  • density (bool | None | None) – 如果为 True,则返回单位面积计数归一化直方图。如果为 False(默认),则返回每个 bin 的(加权)计数。

返回值:

数组元组 (histogram, x_edges, y_edges),其中 histogram 包含聚合数据,x_edgesy_edges 指定 bin 的边界。

返回类型:

tuple[Array, Array, Array]

另请参阅

示例

>>> x = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])
>>> y = jnp.array([2, 5, 6, 8, 13, 16, 17, 18])
>>> counts, x_edges, y_edges = jnp.histogram2d(x, y, bins=8)
>>> counts.shape
(8, 8)
>>> x_edges
Array([ 1.,  4.,  7., 10., 13., 16., 19., 22., 25.], dtype=float32)
>>> y_edges
Array([ 2.,  4.,  6.,  8., 10., 12., 14., 16., 18.], dtype=float32)

指定 bin 范围

>>> counts, x_edges, y_edges = jnp.histogram2d(x, y, range=[(0, 25), (0, 25)], bins=5)
>>> counts.shape
(5, 5)
>>> x_edges
Array([ 0.,  5., 10., 15., 20., 25.], dtype=float32)
>>> y_edges
Array([ 0.,  5., 10., 15., 20., 25.], dtype=float32)

显式指定 bin 边缘

>>> x_edges = jnp.array([0, 10, 20, 30])
>>> y_edges = jnp.array([0, 10, 20, 30])
>>> counts, _, _ = jnp.histogram2d(x, y, bins=[x_edges, y_edges])
>>> counts
Array([[3, 0, 0],
       [1, 3, 0],
       [0, 1, 0]], dtype=int32)

使用 density=True 返回归一化直方图

>>> density, x_edges, y_edges = jnp.histogram2d(x, y, density=True)
>>> dx = jnp.diff(x_edges)
>>> dy = jnp.diff(y_edges)
>>> normed_sum = jnp.sum(density * dx[:, None] * dy[None, :])
>>> jnp.allclose(normed_sum, 1.0)
Array(True, dtype=bool)