jax.numpy.unique#

jax.numpy.unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None, *, equal_nan=True, size=None, fill_value=None)[源代码]#

返回数组中的唯一值。

numpy.unique() 的 JAX 实现。

由于 unique 的输出大小取决于数据,因此该函数通常与 jit() 和其他 JAX 转换不兼容。JAX 版本添加了可选的 size 参数,必须静态指定该参数才能在这些上下文中使用 jnp.unique

参数:
  • ar (类似数组) – 将从中提取唯一值的 N 维数组。

  • return_index (bool) – 如果为 True,则同时返回 ar 中每个值出现的索引

  • return_inverse (bool) – 如果为 True,则同时返回可用于从唯一值重建 ar 的索引。

  • return_counts (bool) – 如果为 True,则同时返回每个唯一值的出现次数。

  • axis (int | None | None) – 如果指定,则计算指定轴上的唯一值。如果为 None(默认),则在计算唯一值之前展平 ar

  • equal_nan (bool) – 如果为 True,则在确定唯一性时认为 NaN 值相等。

  • size (int | None | None) – 如果指定,则仅返回前 size 个排序的唯一元素。如果唯一元素的数量少于 size 指示的数量,则返回值将用 fill_value 填充。

  • fill_value (类似数组 | None | None) – 当指定 size 并且元素数量少于指示的数量时,用 fill_value 填充其余条目。默认为最小唯一值。

返回值:

一个数组或数组元组,具体取决于 return_indexreturn_inversereturn_counts 的值。返回的值为

  • unique_values:

    如果 axis 为 None,则为长度为 n_unique 的 1D 数组。如果指定 axis,则形状为 (*ar.shape[:axis], n_unique, *ar.shape[axis + 1:])

  • unique_index:

    (仅当 return_index 为 True 时返回) 形状为 (n_unique,) 的数组。包含 ar 中每个唯一值第一次出现的索引。对于 1D 输入,ar[unique_index] 等效于 unique_values

  • unique_inverse:

    (仅当 return_inverse 为 True 时返回) 如果 axis 为 None,则为形状为 (ar.size,) 的数组;如果指定 axis,则为形状为 (ar.shape[axis],) 的数组。包含 ar 中每个值在 unique_values 中的索引。对于 1D 输入,unique_values[unique_inverse] 等效于 ar

  • unique_counts:

    (仅当 return_counts 为 True 时返回) 形状为 (n_unique,) 的数组。包含 ar 中每个唯一值的出现次数。

另请参阅

示例

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> jnp.unique(x)
Array([1, 3, 4], dtype=int32)

JIT 编译和 size 参数

如果您在 jit() 或其他转换下尝试此操作,您将收到错误,因为输出形状是动态的

>>> jax.jit(jnp.unique)(x)  
Traceback (most recent call last):
   ...
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[5].
The error arose for the first argument of jnp.unique(). To make jnp.unique() compatible with JIT and other transforms, you can specify a concrete value for the size argument, which will determine the output size.

问题在于转换后的函数的输出必须具有静态形状。为了使其正常工作,您可以传递一个静态 size 参数

>>> jit_unique = jax.jit(jnp.unique, static_argnames=['size'])
>>> jit_unique(x, size=3)
Array([1, 3, 4], dtype=int32)

如果您的静态大小小于唯一值的真实数量,则它们将被截断。

>>> jit_unique(x, size=2)
Array([1, 3], dtype=int32)

如果静态大小大于唯一值的真实数量,则将使用 fill_value 填充它们,默认为最小唯一值

>>> jit_unique(x, size=5)
Array([1, 3, 4, 1, 1], dtype=int32)
>>> jit_unique(x, size=5, fill_value=0)
Array([1, 3, 4, 0, 0], dtype=int32)

多维唯一值

如果您将多维数组传递给 unique,则默认情况下它将被展平

>>> M = jnp.array([[1, 2],
...                [2, 3],
...                [1, 2]])
>>> jnp.unique(M)
Array([1, 2, 3], dtype=int32)

如果您传递一个 axis 关键字,您可以找到沿该轴的数组的唯一切片

>>> jnp.unique(M, axis=0)
Array([[1, 2],
       [2, 3]], dtype=int32)

返回索引

如果您设置 return_index=True,则 unique 返回每个唯一值第一次出现的索引

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> values, indices = jnp.unique(x, return_index=True)
>>> print(values)
[1 3 4]
>>> print(indices)
[2 0 1]
>>> jnp.all(values == x[indices])
Array(True, dtype=bool)

在多个维度中,可以使用沿指定轴评估的 jax.numpy.take() 来提取唯一值

>>> values, indices = jnp.unique(M, axis=0, return_index=True)
>>> jnp.all(values == jnp.take(M, indices, axis=0))
Array(True, dtype=bool)

返回逆索引

如果您设置 return_inverse=True,则 unique 返回输入数组中每个条目的唯一值内的索引

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> values, inverse = jnp.unique(x, return_inverse=True)
>>> print(values)
[1 3 4]
>>> print(inverse)
[1 2 0 1 0]
>>> jnp.all(values[inverse] == x)
Array(True, dtype=bool)

在多个维度中,可以使用 jax.numpy.take() 重建输入

>>> values, inverse = jnp.unique(M, axis=0, return_inverse=True)
>>> jnp.all(jnp.take(values, inverse, axis=0) == M)
Array(True, dtype=bool)

返回计数

如果您设置 return_counts=True,则 unique 返回输入中每个唯一值的出现次数

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> values, counts = jnp.unique(x, return_counts=True)
>>> print(values)
[1 3 4]
>>> print(counts)
[2 2 1]

对于多维数组,这也返回一个 1D 计数数组,指示沿指定轴的出现次数

>>> values, counts = jnp.unique(M, axis=0, return_counts=True)
>>> print(values)
[[1 2]
 [2 3]]
>>> print(counts)
[2 1]