jax.numpy.compress#
- jax.numpy.compress(condition, a, axis=None, *, size=None, fill_value=0, out=None)[source]#
使用布尔条件沿给定轴压缩数组。
JAX 实现
numpy.compress()
。- 参数:
condition (ArrayLike) – 条件的一维数组。将被转换为布尔值。
a (ArrayLike) – 值的 N 维数组。
axis (int | None | None) – 要压缩的轴。如果为 None(默认值),则
a
将被展平,并且轴将被设置为 0。size (int | None | None) – 输出的可选静态大小。必须指定,以便
compress
与 JAX 变换(如jit()
或vmap()
)兼容。fill_value (ArrayLike) – 如果指定了
size
,则使用此值填充填充的条目(默认值:0)。out (None | None) – JAX 未实现。
- 返回:
一个维度为
a.ndim
的数组,沿着指定的轴压缩。- 返回类型:
另请参阅
jax.numpy.extract()
:compress
的一维版本。jax.Array.compress()
:作为数组方法的等效功能。
注释
此函数不需要
condition
和a
之间严格的形状一致性。如果condition.size > a.shape[axis]
,则condition
将被截断,如果a.shape[axis] > condition.size
,则a
将被截断。示例
压缩二维数组的行
>>> a = jnp.array([[1, 2, 3, 4], ... [5, 6, 7, 8], ... [9, 10, 11, 12]]) >>> condition = jnp.array([True, False, True]) >>> jnp.compress(condition, a, axis=0) Array([[ 1, 2, 3, 4], [ 9, 10, 11, 12]], dtype=int32)
为了方便起见,您可以等效地使用 JAX 数组的
compress()
方法>>> a.compress(condition, axis=0) Array([[ 1, 2, 3, 4], [ 9, 10, 11, 12]], dtype=int32)
请注意,条件不需要与指定轴的形状匹配;这里我们使用长度为 3 的条件压缩列。条件大小之外的值将被忽略
>>> jnp.compress(condition, a, axis=1) Array([[ 1, 3], [ 5, 7], [ 9, 11]], dtype=int32)
可选的
size
参数允许您指定静态输出大小,以便输出具有静态形状,因此此函数可与jit()
和vmap()
等转换一起使用。>>> f = lambda c, a: jnp.extract(c, a, size=len(a), fill_value=0) >>> mask = (a % 3 == 0) >>> jax.vmap(f)(mask, a) Array([[ 3, 0, 0, 0], [ 6, 0, 0, 0], [ 9, 12, 0, 0]], dtype=int32)