jax.numpy.histogramdd#
- jax.numpy.histogramdd(sample, bins=10, range=None, weights=None, density=None)[源代码]#
计算 N 维直方图。
numpy.histogramdd()
的 JAX 实现。- 参数:
sample (ArrayLike) – 输入数组,形状为
(N, D)
,表示D
维中的N
个点。bins (ArrayLike | list[ArrayLike]) – 指定直方图每个维度的箱子数量。(默认值:10)。也可以是长度为 D 的整数序列或箱子边缘数组。
range (Sequence[None | Array | Sequence[ArrayLike]] | None | None) – 一个长度为 D 的序列,指定每个维度的范围。如果没有指定,范围将从数据中推断出来。
weights (ArrayLike | None | None) – 一个可选的形状为
(N,)
的数组,指定数据点的权重。应该与sample
的形状相同。如果没有指定,则每个数据点的权重相等。density (bool | None | None) – 如果为 True,返回单位体积的归一化直方图(单位为计数/单位体积)。如果为 False(默认),返回每个 bin 的(加权)计数。
- 返回:
一个数组的元组
(histogram, bin_edges)
,其中histogram
包含聚合数据,bin_edges
指定 bin 的边界。- 返回类型:
另请参阅
jax.numpy.histogram()
: 计算一维数组的直方图。jax.numpy.histogram2d()
: 计算二维数组的直方图。jax.numpy.histogram_bin_edges()
: 计算直方图的 bin 边界。
示例
一个三维空间中 100 个点的直方图
>>> key = jax.random.key(42) >>> a = jax.random.normal(key, (100, 3)) >>> counts, bin_edges = jnp.histogramdd(a, bins=6, ... range=[(-3, 3), (-3, 3), (-3, 3)]) >>> counts.shape (6, 6, 6) >>> bin_edges [Array([-3., -2., -1., 0., 1., 2., 3.], dtype=float32), Array([-3., -2., -1., 0., 1., 2., 3.], dtype=float32), Array([-3., -2., -1., 0., 1., 2., 3.], dtype=float32)]
使用
density=True
返回归一化直方图>>> density, bin_edges = jnp.histogramdd(a, density=True) >>> bin_widths = map(jnp.diff, bin_edges) >>> dx, dy, dz = jnp.meshgrid(*bin_widths, indexing='ij') >>> normed = jnp.sum(density * dx * dy * dz) >>> jnp.allclose(normed, 1.0) Array(True, dtype=bool)