jax.numpy.unique#
- jax.numpy.unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None, *, equal_nan=True, size=None, fill_value=None)[源代码]#
返回数组中的唯一值。
numpy.unique()
的 JAX 实现。由于
unique
的输出大小取决于数据,因此该函数通常与jit()
和其他 JAX 转换不兼容。JAX 版本添加了可选的size
参数,必须静态指定该参数才能在这些上下文中使用jnp.unique
。- 参数:
ar (类似数组) – 将从中提取唯一值的 N 维数组。
return_index (bool) – 如果为 True,则同时返回
ar
中每个值出现的索引return_inverse (bool) – 如果为 True,则同时返回可用于从唯一值重建
ar
的索引。return_counts (bool) – 如果为 True,则同时返回每个唯一值的出现次数。
axis (int | None | None) – 如果指定,则计算指定轴上的唯一值。如果为 None(默认),则在计算唯一值之前展平
ar
。equal_nan (bool) – 如果为 True,则在确定唯一性时认为 NaN 值相等。
size (int | None | None) – 如果指定,则仅返回前
size
个排序的唯一元素。如果唯一元素的数量少于size
指示的数量,则返回值将用fill_value
填充。fill_value (类似数组 | None | None) – 当指定
size
并且元素数量少于指示的数量时,用fill_value
填充其余条目。默认为最小唯一值。
- 返回值:
一个数组或数组元组,具体取决于
return_index
、return_inverse
和return_counts
的值。返回的值为unique_values
:如果
axis
为 None,则为长度为n_unique
的 1D 数组。如果指定axis
,则形状为(*ar.shape[:axis], n_unique, *ar.shape[axis + 1:])
。
unique_index
:(仅当 return_index 为 True 时返回) 形状为
(n_unique,)
的数组。包含ar
中每个唯一值第一次出现的索引。对于 1D 输入,ar[unique_index]
等效于unique_values
。
unique_inverse
:(仅当 return_inverse 为 True 时返回) 如果
axis
为 None,则为形状为(ar.size,)
的数组;如果指定axis
,则为形状为(ar.shape[axis],)
的数组。包含ar
中每个值在unique_values
中的索引。对于 1D 输入,unique_values[unique_inverse]
等效于ar
。
unique_counts
:(仅当 return_counts 为 True 时返回) 形状为
(n_unique,)
的数组。包含ar
中每个唯一值的出现次数。
另请参阅
jax.numpy.unique_counts()
:unique(arr, return_counts=True)
的快捷方式。jax.numpy.unique_inverse()
:unique(arr, return_inverse=True)
的快捷方式。jax.numpy.unique_all()
:具有所有返回值的unique
的快捷方式。jax.numpy.unique_values()
:类似于unique
,但没有可选返回值。
示例
>>> x = jnp.array([3, 4, 1, 3, 1]) >>> jnp.unique(x) Array([1, 3, 4], dtype=int32)
JIT 编译和 size 参数
如果您在
jit()
或其他转换下尝试此操作,您将收到错误,因为输出形状是动态的>>> jax.jit(jnp.unique)(x) Traceback (most recent call last): ... jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[5]. The error arose for the first argument of jnp.unique(). To make jnp.unique() compatible with JIT and other transforms, you can specify a concrete value for the size argument, which will determine the output size.
问题在于转换后的函数的输出必须具有静态形状。为了使其正常工作,您可以传递一个静态
size
参数>>> jit_unique = jax.jit(jnp.unique, static_argnames=['size']) >>> jit_unique(x, size=3) Array([1, 3, 4], dtype=int32)
如果您的静态大小小于唯一值的真实数量,则它们将被截断。
>>> jit_unique(x, size=2) Array([1, 3], dtype=int32)
如果静态大小大于唯一值的真实数量,则将使用
fill_value
填充它们,默认为最小唯一值>>> jit_unique(x, size=5) Array([1, 3, 4, 1, 1], dtype=int32) >>> jit_unique(x, size=5, fill_value=0) Array([1, 3, 4, 0, 0], dtype=int32)
多维唯一值
如果您将多维数组传递给
unique
,则默认情况下它将被展平>>> M = jnp.array([[1, 2], ... [2, 3], ... [1, 2]]) >>> jnp.unique(M) Array([1, 2, 3], dtype=int32)
如果您传递一个
axis
关键字,您可以找到沿该轴的数组的唯一切片>>> jnp.unique(M, axis=0) Array([[1, 2], [2, 3]], dtype=int32)
返回索引
如果您设置
return_index=True
,则unique
返回每个唯一值第一次出现的索引>>> x = jnp.array([3, 4, 1, 3, 1]) >>> values, indices = jnp.unique(x, return_index=True) >>> print(values) [1 3 4] >>> print(indices) [2 0 1] >>> jnp.all(values == x[indices]) Array(True, dtype=bool)
在多个维度中,可以使用沿指定轴评估的
jax.numpy.take()
来提取唯一值>>> values, indices = jnp.unique(M, axis=0, return_index=True) >>> jnp.all(values == jnp.take(M, indices, axis=0)) Array(True, dtype=bool)
返回逆索引
如果您设置
return_inverse=True
,则unique
返回输入数组中每个条目的唯一值内的索引>>> x = jnp.array([3, 4, 1, 3, 1]) >>> values, inverse = jnp.unique(x, return_inverse=True) >>> print(values) [1 3 4] >>> print(inverse) [1 2 0 1 0] >>> jnp.all(values[inverse] == x) Array(True, dtype=bool)
在多个维度中,可以使用
jax.numpy.take()
重建输入>>> values, inverse = jnp.unique(M, axis=0, return_inverse=True) >>> jnp.all(jnp.take(values, inverse, axis=0) == M) Array(True, dtype=bool)
返回计数
如果您设置
return_counts=True
,则unique
返回输入中每个唯一值的出现次数>>> x = jnp.array([3, 4, 1, 3, 1]) >>> values, counts = jnp.unique(x, return_counts=True) >>> print(values) [1 3 4] >>> print(counts) [2 2 1]
对于多维数组,这也返回一个 1D 计数数组,指示沿指定轴的出现次数
>>> values, counts = jnp.unique(M, axis=0, return_counts=True) >>> print(values) [[1 2] [2 3]] >>> print(counts) [2 1]