jax.numpy.extract#
- jax.numpy.extract(condition, arr, *, size=None, fill_value=0)[源代码]#
返回数组中满足条件的元素。
numpy.extract()
的 JAX 实现。- 参数:
- 返回:
提取条目的 1D 数组。如果指定了
size
,则结果将具有形状(size,)
并用fill_value
右填充。如果未指定size
,则输出形状将取决于condition
中 True 条目的数量。- 返回类型:
备注
此函数不要求
condition
和arr
之间严格的形状一致。如果condition.size > arr.size
,则condition
将被截断;如果arr.size > condition.size
,则arr
将被截断。另请参阅
jax.numpy.compress()
:extract
的多维版本。示例
从 1D 数组中提取值
>>> x = jnp.array([1, 2, 3, 4, 5, 6]) >>> mask = (x % 2 == 0) >>> jnp.extract(mask, x) Array([2, 4, 6], dtype=int32)
在最简单的情况下,这等效于布尔索引
>>> x[mask] Array([2, 4, 6], dtype=int32)
为了与 JAX 转换一起使用,您可以传递
size
参数来为输出指定静态形状,以及可选的fill_value
(默认为零)>>> jnp.extract(mask, x, size=len(x), fill_value=0) Array([2, 4, 6, 0, 0, 0], dtype=int32)
请注意,与布尔索引不同,
extract
不要求数组和条件的大小严格一致,并且会有效地将两者都截断为最小尺寸>>> short_mask = jnp.array([False, True]) >>> jnp.extract(short_mask, x) Array([2], dtype=int32) >>> long_mask = jnp.array([True, False, True, False, False, False, False, False]) >>> jnp.extract(long_mask, x) Array([1, 3], dtype=int32)