jax.numpy.histogramdd#

jax.numpy.histogramdd(sample, bins=10, range=None, weights=None, density=None)[源代码]#

计算 N 维直方图。

numpy.histogramdd() 的 JAX 实现。

参数:
  • sample (ArrayLike) – 输入数组,形状为 (N, D),表示 D 维中的 N 个点。

  • bins (ArrayLike | list[ArrayLike]) – 指定直方图每个维度的箱子数量。(默认值:10)。也可以是长度为 D 的整数序列或箱子边缘数组。

  • range (Sequence[None | Array | Sequence[ArrayLike]] | None | None) – 一个长度为 D 的序列,指定每个维度的范围。如果没有指定,范围将从数据中推断出来。

  • weights (ArrayLike | None | None) – 一个可选的形状为 (N,) 的数组,指定数据点的权重。应该与 sample 的形状相同。如果没有指定,则每个数据点的权重相等。

  • density (bool | None | None) – 如果为 True,返回单位体积的归一化直方图(单位为计数/单位体积)。如果为 False(默认),返回每个 bin 的(加权)计数。

返回:

一个数组的元组 (histogram, bin_edges),其中 histogram 包含聚合数据,bin_edges 指定 bin 的边界。

返回类型:

tuple[Array, list[Array]]

另请参阅

示例

一个三维空间中 100 个点的直方图

>>> key = jax.random.key(42)
>>> a = jax.random.normal(key, (100, 3))
>>> counts, bin_edges = jnp.histogramdd(a, bins=6,
...                                     range=[(-3, 3), (-3, 3), (-3, 3)])
>>> counts.shape
(6, 6, 6)
>>> bin_edges  
[Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32),
 Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32),
 Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32)]

使用 density=True 返回归一化直方图

>>> density, bin_edges = jnp.histogramdd(a, density=True)
>>> bin_widths = map(jnp.diff, bin_edges)
>>> dx, dy, dz = jnp.meshgrid(*bin_widths, indexing='ij')
>>> normed = jnp.sum(density * dx * dy * dz)
>>> jnp.allclose(normed, 1.0)
Array(True, dtype=bool)