jax.numpy.extract#

jax.numpy.extract(condition, arr, *, size=None, fill_value=0)[源代码]#

返回数组中满足条件的元素。

JAX 实现的 numpy.extract()

参数:
  • condition (ArrayLike) – 条件数组。将被转换为布尔值并展平为 1D 数组。

  • arr (ArrayLike) – 要提取的值的数组。将被展平为 1D 数组。

  • size (int | None | None) – 可选的输出静态大小。为了使 extract 与 JAX 转换(如 jit()vmap())兼容,必须指定此参数。

  • fill_value (ArrayLike) – 如果指定了 size,则用此值填充补齐的条目(默认为 0)。

返回值:

提取的条目的 1D 数组。如果指定了 size,则结果的形状将为 (size,),并用 fill_value 在右侧进行填充。如果未指定 size,则输出形状将取决于 condition 中 True 条目的数量。

返回类型:

数组

笔记

此函数不要求 conditionarr 之间具有严格的形状一致性。如果 condition.size > arr.size,则将截断 condition;如果 arr.size > condition.size,则将截断 arr

另请参阅

jax.numpy.compress()extract 的多维版本。

示例

从 1D 数组中提取值

>>> x = jnp.array([1, 2, 3, 4, 5, 6])
>>> mask = (x % 2 == 0)
>>> jnp.extract(mask, x)
Array([2, 4, 6], dtype=int32)

在最简单的情况下,这等效于布尔索引

>>> x[mask]
Array([2, 4, 6], dtype=int32)

为了与 JAX 转换一起使用,您可以传递 size 参数来指定输出的静态形状,以及一个可选的 fill_value,默认为零

>>> jnp.extract(mask, x, size=len(x), fill_value=0)
Array([2, 4, 6, 0, 0, 0], dtype=int32)

请注意,与布尔索引不同,extract 不要求数组和条件的大小严格一致,并且会有效地将两者都截断为最小的大小

>>> short_mask = jnp.array([False, True])
>>> jnp.extract(short_mask, x)
Array([2], dtype=int32)
>>> long_mask = jnp.array([True, False, True, False, False, False, False, False])
>>> jnp.extract(long_mask, x)
Array([1, 3], dtype=int32)