jax.numpy.extract#

jax.numpy.extract(condition, arr, *, size=None, fill_value=0)[源代码]#

返回数组中满足条件的元素。

numpy.extract() 的 JAX 实现。

参数:
  • condition (类数组) – 条件数组。将被转换为布尔值并展平为 1D。

  • arr (类数组) – 要提取的值的数组。将被展平为 1D。

  • size (int | None | None) – 输出的可选静态大小。为了使 extract 与 JAX 转换(如 jit()vmap())兼容,必须指定此参数。

  • fill_value (类数组) – 如果指定了 size,则用此值填充填充的条目(默认值:0)。

返回:

提取条目的 1D 数组。如果指定了 size,则结果将具有形状 (size,) 并用 fill_value 右填充。如果未指定 size,则输出形状将取决于 condition 中 True 条目的数量。

返回类型:

数组

备注

此函数不要求 conditionarr 之间严格的形状一致。如果 condition.size > arr.size,则 condition 将被截断;如果 arr.size > condition.size,则 arr 将被截断。

另请参阅

jax.numpy.compress()extract 的多维版本。

示例

从 1D 数组中提取值

>>> x = jnp.array([1, 2, 3, 4, 5, 6])
>>> mask = (x % 2 == 0)
>>> jnp.extract(mask, x)
Array([2, 4, 6], dtype=int32)

在最简单的情况下,这等效于布尔索引

>>> x[mask]
Array([2, 4, 6], dtype=int32)

为了与 JAX 转换一起使用,您可以传递 size 参数来为输出指定静态形状,以及可选的 fill_value(默认为零)

>>> jnp.extract(mask, x, size=len(x), fill_value=0)
Array([2, 4, 6, 0, 0, 0], dtype=int32)

请注意,与布尔索引不同,extract 不要求数组和条件的大小严格一致,并且会有效地将两者都截断为最小尺寸

>>> short_mask = jnp.array([False, True])
>>> jnp.extract(short_mask, x)
Array([2], dtype=int32)
>>> long_mask = jnp.array([True, False, True, False, False, False, False, False])
>>> jnp.extract(long_mask, x)
Array([1, 3], dtype=int32)