jax.numpy.squeeze#

jax.numpy.squeeze(a, axis=None)[源代码]#

从数组中删除一个或多个长度为 1 的轴

numpy.sqeeze() 的 JAX 实现,通过 jax.lax.squeeze() 实现。

参数:
  • a (ArrayLike) – 输入数组

  • axis (int | Sequence[int] | None | None) – 指定要删除的轴的整数或整数序列。如果任何指定的轴的长度不是 1,则会引发错误。如果未指定,则会压缩 a 中所有长度为 1 的轴。

返回:

删除长度为 1 的轴后的 a 的副本。

返回类型:

Array

备注

numpy.squeeze() 不同,jax.numpy.squeeze() 将返回输入数组的副本,而不是视图。然而,在 JIT 下,编译器会在可能的情况下优化掉这些副本,因此实际上不会对性能产生影响。

另请参阅

示例

>>> x = jnp.array([[[0]], [[1]], [[2]]])
>>> x.shape
(3, 1, 1)

压缩所有长度为 1 的维度

>>> jnp.squeeze(x)
Array([0, 1, 2], dtype=int32)
>>> _.shape
(3,)

显式指定轴时的等效操作

>>> jnp.squeeze(x, axis=(1, 2))
Array([0, 1, 2], dtype=int32)

尝试压缩非单位轴会导致错误

>>> jnp.squeeze(x, axis=0)  
Traceback (most recent call last):
  ...
ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,)

为了方便起见,此功能也可通过 jax.Array.squeeze() 方法获得

>>> x.squeeze()
Array([0, 1, 2], dtype=int32)