jax.numpy.compress#

jax.numpy.compress(condition, a, axis=None, *, size=None, fill_value=0, out=None)[源代码]#

使用布尔条件沿给定轴压缩数组。

numpy.compress() 的 JAX 实现。

参数:
  • condition (类数组) – 一维条件数组。将会被转换为布尔值。

  • a (类数组) – N 维数值数组。

  • axis (int | None | None) – 压缩所沿的轴。如果为 None (默认值),则 a 将被展平,并且轴将被设置为 0。

  • size (int | None | None) – 输出的可选静态大小。必须指定此参数,才能使 compress 与 JAX 转换(如 jit()vmap())兼容。

  • fill_value (类数组) – 如果指定了 size,则用此值填充补齐的条目 (默认值:0)。

  • out (None | None) – JAX 未实现。

返回值:

一个维度为 a.ndim 的数组,沿指定轴压缩。

返回类型:

数组

另请参阅

注释

此函数不要求 conditiona 之间严格的形状一致性。如果 condition.size > a.shape[axis],则 condition 将被截断;如果 a.shape[axis] > condition.size,则 a 将被截断。

示例

沿二维数组的行压缩

>>> a = jnp.array([[1,  2,  3,  4],
...                [5,  6,  7,  8],
...                [9,  10, 11, 12]])
>>> condition = jnp.array([True, False, True])
>>> jnp.compress(condition, a, axis=0)
Array([[ 1,  2,  3,  4],
       [ 9, 10, 11, 12]], dtype=int32)

为方便起见,您可以等效地使用 JAX 数组的 compress() 方法

>>> a.compress(condition, axis=0)
Array([[ 1,  2,  3,  4],
       [ 9, 10, 11, 12]], dtype=int32)

请注意,条件不必与指定轴的形状匹配;这里我们使用长度为 3 的条件压缩列。超出条件大小的值将被忽略

>>> jnp.compress(condition, a, axis=1)
Array([[ 1,  3],
       [ 5,  7],
       [ 9, 11]], dtype=int32)

可选的 size 参数允许您指定静态输出大小,以便输出是静态形状的,因此此函数可以与 jit()vmap() 等变换一起使用

>>> f = lambda c, a: jnp.extract(c, a, size=len(a), fill_value=0)
>>> mask = (a % 3 == 0)
>>> jax.vmap(f)(mask, a)
Array([[ 3,  0,  0,  0],
       [ 6,  0,  0,  0],
       [ 9, 12,  0,  0]], dtype=int32)