jax.numpy.digitize#
- jax.numpy.digitize(x, bins, right=False, *, method=None)[源代码]#
将数组转换为 bin 索引。
numpy.digitize()
的 JAX 实现。- 参数:
x (ArrayLike) – 要数字化的值的数组。
bins (ArrayLike) – 1D bin 边缘数组。必须单调递增或递减。
right (bool) – 如果为 true,则区间包括右侧 bin 边缘。如果为 false(默认),则区间包括左侧 bin 边缘。
method (str | None) – 传递给
searchsorted()
的可选方法参数。请参阅该函数以获取可用选项。
- 返回:
一个与
x
形状相同的整数数组,指示这些值所在的 bin 编号。- 返回类型:
另请参阅
jax.numpy.searchsorted()
:查找排序数组中值的插入索引。jax.numpy.histogram()
:计算指定 bin 中数组值的频率。
示例
>>> x = jnp.array([1.0, 2.0, 2.5, 1.5, 3.0, 3.5]) >>> bins = jnp.array([1, 2, 3]) >>> jnp.digitize(x, bins) Array([1, 2, 2, 1, 3, 3], dtype=int32) >>> jnp.digitize(x, bins, right=True) Array([0, 1, 2, 1, 2, 3], dtype=int32)
digitize
也支持反向排序的 bin>>> bins = jnp.array([3, 2, 1]) >>> jnp.digitize(x, bins) Array([2, 1, 1, 2, 0, 0], dtype=int32)