jax.numpy.unique_inverse

jax.numpy.unique_inverse#

jax.numpy.unique_inverse(x, /, *, size=None, fill_value=None)[source]#

从 x 返回唯一值,以及索引、反向索引和计数。

JAX 实现 numpy.unique_inverse(); 这等同于调用 jax.numpy.unique() 并将 return_inverseequal_nan 设置为 True。

由于 unique_inverse 输出的大小取决于数据,因此该函数通常不兼容 jit() 和其他 JAX 变换。JAX 版本添加了可选的 size 参数,该参数必须在静态情况下指定,以便 jnp.unique 在此类上下文中使用。

参数:
  • x (ArrayLike) – 将从中提取唯一值的 N 维数组。

  • size (int | None | None) – 如果指定,则仅返回排名前 size 的排序唯一元素。如果唯一元素少于 size 指示的个数,则返回值将用 fill_value 填充。

  • fill_value (ArrayLike | None | None) – 当 size 指定并且元素少于指示的个数时,用 fill_value 填充剩余条目。默认为最小唯一值。

返回值:

  • values:

    形状为 (n_unique,) 的数组,包含来自 x 的唯一值。

  • inverse_indices:

    形状为 x.shape 的数组。包含 x 中每个值在 values 中的索引。对于 1D 输入,values[inverse_indices] 等效于 x

返回类型:

一个元组 (values, indices, inverse_indices, counts),具有以下属性

参见

示例

这里我们计算 1D 数组中的唯一值

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> result = jnp.unique_inverse(x)

结果是一个 NamedTuple,包含两个命名属性。 values 属性包含数组中的唯一值

>>> result.values
Array([1, 3, 4], dtype=int32)

indices 属性包含输入数组中唯一 values 的索引

inverse_indices 属性包含输入在 values 中的索引

>>> result.inverse_indices
Array([1, 2, 0, 1, 0], dtype=int32)
>>> jnp.all(x == result.values[result.inverse_indices])
Array(True, dtype=bool)

有关 sizefill_value 参数的示例,请参见 jax.numpy.unique()