jax.numpy.extract#
- jax.numpy.extract(condition, arr, *, size=None, fill_value=0)[源代码]#
返回数组中满足条件的元素。
JAX 实现的
numpy.extract()
。- 参数:
- 返回值:
提取的条目的 1D 数组。如果指定了
size
,则结果的形状将为(size,)
,并用fill_value
在右侧进行填充。如果未指定size
,则输出形状将取决于condition
中 True 条目的数量。- 返回类型:
笔记
此函数不要求
condition
和arr
之间具有严格的形状一致性。如果condition.size > arr.size
,则将截断condition
;如果arr.size > condition.size
,则将截断arr
。另请参阅
jax.numpy.compress()
:extract
的多维版本。示例
从 1D 数组中提取值
>>> x = jnp.array([1, 2, 3, 4, 5, 6]) >>> mask = (x % 2 == 0) >>> jnp.extract(mask, x) Array([2, 4, 6], dtype=int32)
在最简单的情况下,这等效于布尔索引
>>> x[mask] Array([2, 4, 6], dtype=int32)
为了与 JAX 转换一起使用,您可以传递
size
参数来指定输出的静态形状,以及一个可选的fill_value
,默认为零>>> jnp.extract(mask, x, size=len(x), fill_value=0) Array([2, 4, 6, 0, 0, 0], dtype=int32)
请注意,与布尔索引不同,
extract
不要求数组和条件的大小严格一致,并且会有效地将两者都截断为最小的大小>>> short_mask = jnp.array([False, True]) >>> jnp.extract(short_mask, x) Array([2], dtype=int32) >>> long_mask = jnp.array([True, False, True, False, False, False, False, False]) >>> jnp.extract(long_mask, x) Array([1, 3], dtype=int32)