jax.numpy.intersect1d

内容

jax.numpy.intersect1d#

jax.numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False, *, size=None, fill_value=None)[source]#

计算两个一维数组的交集。

JAX 实现的 numpy.intersect1d()

由于intersect1d输出的大小依赖于数据,因此该函数通常与jit()和其他 JAX 变换不兼容。JAX 版本添加了可选的size参数,该参数必须静态指定才能在这些上下文中使用jnp.intersect1d

参数:
  • ar1 (ArrayLike) – 要进行交集运算的第一个数组。

  • ar2 (ArrayLike) – 要进行交集运算的第二个数组。

  • assume_unique (bool) – 如果为 True,则假设输入数组包含唯一值。这允许更有效的实现,但如果assume_unique为 True 且输入数组包含重复值,则行为未定义。默认值:False。

  • return_indices (bool) – 如果为 True,则返回指定交集值在输入数组中首次出现位置的索引数组。

  • size (int | None | None) – 如果指定,则仅返回前size个排序元素。如果元素少于size指示的数量,则返回值将用fill_value填充,返回的索引将用超出范围的索引填充。

  • fill_value (ArrayLike | None | None) – 当size指定且元素少于指示的数量时,用fill_value填充剩余的条目。默认为交集中最小的值。

返回值:

一个数组intersection,或者如果return_indices=True,则为一个数组元组(intersection, ar1_indices, ar2_indices)。返回值为

  • intersection:一个 1D 数组,包含同时出现在ar1ar2中的每个值。

  • ar1_indices(如果 return_indices=True 则返回) 形状为intersection.shape的数组,包含intersection中值在扁平化ar1中的索引。对于 1D 输入,intersection等效于ar1[ar1_indices]

  • ar2_indices(如果 return_indices=True 则返回) 形状为intersection.shape的数组,包含intersection中值在扁平化ar2中的索引。对于 1D 输入,intersection等效于ar2[ar2_indices]

返回类型:

Array | tuple[Array, Array, Array]

另请参见

示例

>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.intersect1d(ar1, ar2)
Array([3, 4], dtype=int32)

计算交集和索引

>>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True)
>>> intersection
Array([3, 4], dtype=int32)

ar1_indices给出交集值在ar1中的索引

>>> ar1_indices
Array([2, 3], dtype=int32)
>>> jnp.all(intersection == ar1[ar1_indices])
Array(True, dtype=bool)

ar2_indices给出交集值在ar2中的索引

>>> ar2_indices
Array([0, 1], dtype=int32)
>>> jnp.all(intersection == ar2[ar2_indices])
Array(True, dtype=bool)