jax.numpy.intersect1d#
- jax.numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False, *, size=None, fill_value=None)[源代码]#
计算两个一维数组的集合交集。
numpy.intersect1d()
的 JAX 实现。由于
intersect1d
的输出大小取决于数据,因此该函数通常与jit()
和其他 JAX 转换不兼容。JAX 版本添加了可选的size
参数,必须静态指定该参数才能在这些上下文中使用jnp.intersect1d
。- 参数:
ar1 (ArrayLike) – 第一个要计算交集的数组。
ar2 (ArrayLike) – 第二个要计算交集的数组。
assume_unique (bool) – 如果为 True,则假设输入数组包含唯一值。这允许更高效的实现,但如果
assume_unique
为 True 且输入数组包含重复项,则行为未定义。默认值:False。return_indices (bool) – 如果为 True,则返回索引数组,指定交集值首次出现在输入数组中的位置。
size (int | None | None) – 如果指定,则仅返回前
size
个排序元素。如果元素少于size
指示的元素,则返回值将用fill_value
填充,并且返回的索引将用超出范围的索引填充。fill_value (ArrayLike | None | None) – 当指定
size
且元素少于指示的数量时,用fill_value
填充剩余条目。默认为交集中的最小值。
- 返回值:
一个数组
intersection
,或者如果return_indices=True
,则返回数组元组(intersection, ar1_indices, ar2_indices)
。返回的值为intersection
: 一个一维数组,包含同时出现在ar1
和ar2
中的每个值。ar1_indices
: (如果 return_indices=True 则返回) 一个形状为intersection.shape
的数组,包含intersection
中值在扁平化后的ar1
中的索引。对于一维输入,intersection
等效于ar1[ar1_indices]
。ar2_indices
: (如果 return_indices=True 则返回) 一个形状为intersection.shape
的数组,包含intersection
中值在扁平化后的ar2
中的索引。对于一维输入,intersection
等效于ar2[ar2_indices]
。
- 返回类型:
另请参阅
jax.numpy.union1d()
:两个一维数组的集合并集。jax.numpy.setxor1d()
:两个一维数组的集合异或。jax.numpy.setdiff1d()
:两个一维数组的集合差。
示例
>>> ar1 = jnp.array([1, 2, 3, 4]) >>> ar2 = jnp.array([3, 4, 5, 6]) >>> jnp.intersect1d(ar1, ar2) Array([3, 4], dtype=int32)
计算带索引的交集
>>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True) >>> intersection Array([3, 4], dtype=int32)
ar1_indices
给出了交集值在ar1
中的索引>>> ar1_indices Array([2, 3], dtype=int32) >>> jnp.all(intersection == ar1[ar1_indices]) Array(True, dtype=bool)
ar2_indices
给出了交集值在ar2
中的索引>>> ar2_indices Array([0, 1], dtype=int32) >>> jnp.all(intersection == ar2[ar2_indices]) Array(True, dtype=bool)