jax.numpy.squeeze#

jax.numpy.squeeze(a, axis=None)[源代码]#

从数组中移除一个或多个长度为1的轴。

JAX 实现的 numpy.sqeeze(),通过 jax.lax.squeeze() 实现。

参数:
  • a (ArrayLike) – 输入数组

  • axis (int | Sequence[int] | None | None) – 指定要移除的轴的整数或整数序列。如果任何指定的轴的长度不为1,则会引发错误。如果未指定,则压缩 a 中所有长度为1的轴。

返回值:

移除长度为1的轴的 a 的副本。

返回类型:

Array

备注

numpy.squeeze() 不同,jax.numpy.squeeze() 将返回输入数组的副本而不是视图。但是,在 JIT 下,编译器会在可能的情况下优化掉这些副本,因此在实践中不会对性能产生影响。

另请参阅

示例

>>> x = jnp.array([[[0]], [[1]], [[2]]])
>>> x.shape
(3, 1, 1)

压缩所有长度为 1 的维度

>>> jnp.squeeze(x)
Array([0, 1, 2], dtype=int32)
>>> _.shape
(3,)

显式指定轴时的等效操作

>>> jnp.squeeze(x, axis=(1, 2))
Array([0, 1, 2], dtype=int32)

尝试压缩非单位轴会导致错误

>>> jnp.squeeze(x, axis=0)  
Traceback (most recent call last):
  ...
ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,)

为方便起见,此功能也可以通过 jax.Array.squeeze() 方法使用

>>> x.squeeze()
Array([0, 1, 2], dtype=int32)