jax.numpy.intersect1d#
- jax.numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False, *, size=None, fill_value=None)[source]#
计算两个一维数组的交集。
JAX 实现的
numpy.intersect1d()
。由于
intersect1d
输出的大小依赖于数据,因此该函数通常与jit()
和其他 JAX 变换不兼容。JAX 版本添加了可选的size
参数,该参数必须静态指定才能在这些上下文中使用jnp.intersect1d
。- 参数:
ar1 (ArrayLike) – 要进行交集运算的第一个数组。
ar2 (ArrayLike) – 要进行交集运算的第二个数组。
assume_unique (bool) – 如果为 True,则假设输入数组包含唯一值。这允许更有效的实现,但如果
assume_unique
为 True 且输入数组包含重复值,则行为未定义。默认值:False。return_indices (bool) – 如果为 True,则返回指定交集值在输入数组中首次出现位置的索引数组。
size (int | None | None) – 如果指定,则仅返回前
size
个排序元素。如果元素少于size
指示的数量,则返回值将用fill_value
填充,返回的索引将用超出范围的索引填充。fill_value (ArrayLike | None | None) – 当
size
指定且元素少于指示的数量时,用fill_value
填充剩余的条目。默认为交集中最小的值。
- 返回值:
一个数组
intersection
,或者如果return_indices=True
,则为一个数组元组(intersection, ar1_indices, ar2_indices)
。返回值为intersection
:一个 1D 数组,包含同时出现在ar1
和ar2
中的每个值。ar1_indices
:(如果 return_indices=True 则返回) 形状为intersection.shape
的数组,包含intersection
中值在扁平化ar1
中的索引。对于 1D 输入,intersection
等效于ar1[ar1_indices]
。ar2_indices
:(如果 return_indices=True 则返回) 形状为intersection.shape
的数组,包含intersection
中值在扁平化ar2
中的索引。对于 1D 输入,intersection
等效于ar2[ar2_indices]
。
- 返回类型:
另请参见
jax.numpy.union1d()
:两个 1D 数组的并集。jax.numpy.setxor1d()
:两个 1D 数组的异或集。jax.numpy.setdiff1d()
:两个 1D 数组的差集。
示例
>>> ar1 = jnp.array([1, 2, 3, 4]) >>> ar2 = jnp.array([3, 4, 5, 6]) >>> jnp.intersect1d(ar1, ar2) Array([3, 4], dtype=int32)
计算交集和索引
>>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True) >>> intersection Array([3, 4], dtype=int32)
ar1_indices
给出交集值在ar1
中的索引>>> ar1_indices Array([2, 3], dtype=int32) >>> jnp.all(intersection == ar1[ar1_indices]) Array(True, dtype=bool)
ar2_indices
给出交集值在ar2
中的索引>>> ar2_indices Array([0, 1], dtype=int32) >>> jnp.all(intersection == ar2[ar2_indices]) Array(True, dtype=bool)