jax.numpy.union1d

内容

jax.numpy.union1d#

jax.numpy.union1d(ar1, ar2, *, size=None, fill_value=None)[source]#

计算两个一维数组的并集。

JAX 实现 numpy.union1d().

由于 union1d 输出的大小取决于数据,因此该函数通常与 jit() 和其他 JAX 变换不兼容。JAX 版本添加了可选的 size 参数,该参数必须静态指定,才能在这些上下文中使用 jnp.union1d

参数:
  • ar1 (ArrayLike) – 要合并的第一个元素数组。

  • ar2 (ArrayLike) – 要合并的第二个元素数组

  • size (int | None | None) – 如果指定,则仅返回前 size 个排序元素。如果元素数量少于 size 所指示的数量,则返回值将用 fill_value 填充。

  • fill_value (ArrayLike | None | None) – 当指定 size 且元素数量少于指示的数量时,用 fill_value 填充剩余条目。默认值为最小值。

返回值:

包含输入数组中元素的并集的数组。

返回类型:

数组

参见

示例

计算两个数组的并集

>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.union1d(ar1, ar2)
Array([1, 2, 3, 4, 5, 6], dtype=int32)

由于输出形状是动态的,这将在 jit() 和其他转换下失败

>>> jax.jit(jnp.union1d)(ar1, ar2)  
Traceback (most recent call last):
   ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
The error occurred while tracing the function union1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.

为了确保静态已知的输出形状,您可以传递一个静态的 size 参数

>>> jit_union1d = jax.jit(jnp.union1d, static_argnames=['size'])
>>> jit_union1d(ar1, ar2, size=6)
Array([1, 2, 3, 4, 5, 6], dtype=int32)

如果 size 太小,则并集将被截断

>>> jit_union1d(ar1, ar2, size=4)
Array([1, 2, 3, 4], dtype=int32)

如果 size 太大,则输出将用 fill_value 填充

>>> jit_union1d(ar1, ar2, size=8, fill_value=0)
Array([1, 2, 3, 4, 5, 6, 0, 0], dtype=int32)