jax.numpy.bincount#
- jax.numpy.bincount(x, weights=None, minlength=0, *, length=None)[源代码]#
计算整数数组中每个值的出现次数。
JAX 实现的
numpy.bincount()
。对于正整数数组
x
,此函数返回一个大小为x.max() + 1
的数组counts
,使得counts[i]
包含值i
在x
中出现的次数。JAX 版本与 NumPy 版本有一些差异
在 NumPy 中,传递一个带有负数条目的数组
x
将导致错误。在 JAX 中,负值会被截断为零。JAX 添加了一个可选的
length
参数,该参数可用于静态指定输出数组的长度,以便此函数可以与jax.jit()
等转换一起使用。在这种情况下,大于 length + 1 的项目将被删除。
- 参数:
- 返回:
一个计数数组或求和权重数组,反映
x
中值的出现次数。- 返回类型:
示例
基本 bincount
>>> x = jnp.array([1, 1, 2, 3, 3, 3]) >>> jnp.bincount(x) Array([0, 2, 1, 3], dtype=int32)
加权 bincount
>>> weights = jnp.array([1, 2, 3, 4, 5, 6]) >>> jnp.bincount(x, weights) Array([ 0, 3, 3, 15], dtype=int32)
指定静态
length
使其与 jit 兼容>>> jit_bincount = jax.jit(jnp.bincount, static_argnames=['length']) >>> jit_bincount(x, length=5) Array([0, 2, 1, 3, 0], dtype=int32)
任何负数都会被截断到第一个 bin,并且超出指定
length
的数字将被删除>>> x = jnp.array([-1, -1, 1, 3, 10]) >>> jnp.bincount(x, length=5) Array([2, 1, 0, 1, 0], dtype=int32)