jax.numpy.unique_values#

jax.numpy.unique_values(x, /, *, size=None, fill_value=None)[源代码]#

返回 x 中的唯一值,以及索引、逆索引和计数。

numpy.unique_values() 的 JAX 实现;这等价于调用 jax.numpy.unique() 并将 equal_nan 设置为 True。

由于 unique_values 的输出大小取决于数据,该函数通常与 jit() 和其他 JAX 转换不兼容。 JAX 版本添加了可选的 size 参数,该参数必须静态指定,以便在这些情况下使用 jnp.unique

参数:
  • x (ArrayLike) – 从中提取唯一值的 N 维数组。

  • size (int | None | None) – 如果指定,则仅返回前 size 个排序后的唯一元素。 如果唯一元素的数量少于 size 指示的数量,则返回值将使用 fill_value 填充。

  • fill_value (ArrayLike | None | None) – 当指定了 size 且元素数量少于指示的数量时,将剩余条目填充为 fill_value。 默认为最小唯一值。

返回:

一个形状为 (n_unique,) 的数组 values,其中包含来自 x 的唯一值。

返回类型:

数组

另请参阅

示例

这里我们计算一个一维数组中的唯一值

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> jnp.unique_values(x)
Array([1, 3, 4], dtype=int32)

有关 sizefill_value 参数的示例,请参阅 jax.numpy.unique()