jax.numpy.extract#
- jax.numpy.extract(condition, arr, *, size=None, fill_value=0)[source]#
返回满足条件的数组元素。
JAX 实现
numpy.extract()
.- 参数::
- 返回::
从提取的条目中获取的 1D 数组。如果指定了
size
,则结果将具有形状(size,)
,并用fill_value
进行右填充。如果未指定size
,则输出形状将取决于condition
中 True 条目的数量。- 返回类型:
笔记
此函数不需要
condition
和arr
之间严格的形状一致性。如果condition.size > arr.size
,则condition
将被截断,如果arr.size > condition.size
,则arr
将被截断。另请参阅
jax.numpy.compress()
:extract
的多维版本。示例
从 1D 数组中提取值
>>> x = jnp.array([1, 2, 3, 4, 5, 6]) >>> mask = (x % 2 == 0) >>> jnp.extract(mask, x) Array([2, 4, 6], dtype=int32)
在最简单的情况下,这等效于布尔索引
>>> x[mask] Array([2, 4, 6], dtype=int32)
为了与 JAX 转换一起使用,您可以传递
size
参数以指定输出的静态形状,以及可选的fill_value
,默认值为零>>> jnp.extract(mask, x, size=len(x), fill_value=0) Array([2, 4, 6, 0, 0, 0], dtype=int32)
请注意,与布尔索引不同,
extract
不需要数组和条件的大小之间严格一致,并将有效地将两者截断到最小大小>>> short_mask = jnp.array([False, True]) >>> jnp.extract(short_mask, x) Array([2], dtype=int32) >>> long_mask = jnp.array([True, False, True, False, False, False, False, False]) >>> jnp.extract(long_mask, x) Array([1, 3], dtype=int32)