jax.numpy.extract

内容

jax.numpy.extract#

jax.numpy.extract(condition, arr, *, size=None, fill_value=0)[source]#

返回满足条件的数组元素。

JAX 实现 numpy.extract().

参数::
  • condition (ArrayLike) – 条件数组。将被转换为布尔值并展平为 1D。

  • arr (ArrayLike) – 要提取的值数组。将被展平为 1D。

  • size (int | None | None) – 输出的可选静态大小。必须指定才能使 extract 与 JAX 转换(如 jit()vmap())兼容。

  • fill_value (ArrayLike) – 如果指定了 size,则使用此值填充填充的条目(默认值:0)。

返回::

从提取的条目中获取的 1D 数组。如果指定了 size,则结果将具有形状 (size,),并用 fill_value 进行右填充。如果未指定 size,则输出形状将取决于 condition 中 True 条目的数量。

返回类型:

数组

笔记

此函数不需要 conditionarr 之间严格的形状一致性。如果 condition.size > arr.size,则 condition 将被截断,如果 arr.size > condition.size,则 arr 将被截断。

另请参阅

jax.numpy.compress(): extract 的多维版本。

示例

从 1D 数组中提取值

>>> x = jnp.array([1, 2, 3, 4, 5, 6])
>>> mask = (x % 2 == 0)
>>> jnp.extract(mask, x)
Array([2, 4, 6], dtype=int32)

在最简单的情况下,这等效于布尔索引

>>> x[mask]
Array([2, 4, 6], dtype=int32)

为了与 JAX 转换一起使用,您可以传递 size 参数以指定输出的静态形状,以及可选的 fill_value,默认值为零

>>> jnp.extract(mask, x, size=len(x), fill_value=0)
Array([2, 4, 6, 0, 0, 0], dtype=int32)

请注意,与布尔索引不同,extract 不需要数组和条件的大小之间严格一致,并将有效地将两者截断到最小大小

>>> short_mask = jnp.array([False, True])
>>> jnp.extract(short_mask, x)
Array([2], dtype=int32)
>>> long_mask = jnp.array([True, False, True, False, False, False, False, False])
>>> jnp.extract(long_mask, x)
Array([1, 3], dtype=int32)