jax.numpy.squeeze

内容

jax.numpy.squeeze#

jax.numpy.squeeze(a, axis=None)[source]#

从数组中移除一个或多个长度为 1 的轴

JAX 实现 numpy.sqeeze(),通过 jax.lax.squeeze() 实现。

参数:
  • a (ArrayLike) – 输入数组

  • axis (int | Sequence[int] | None | None) – 整数或整数序列,指定要移除的轴。如果任何指定的轴长度不为 1,则会引发错误。如果未指定,则压缩 a 中所有长度为 1 的轴。

返回值:

移除长度为 1 的轴后的 a 的副本。

返回类型:

Array

注释

numpy.squeeze() 不同,jax.numpy.squeeze() 将返回输入数组的副本,而不是视图。但是,在 JIT 下,编译器将在可能的情况下优化掉此类副本,因此在实践中不会影响性能。

另请参阅

示例

>>> x = jnp.array([[[0]], [[1]], [[2]]])
>>> x.shape
(3, 1, 1)

压缩所有长度为 1 的维度

>>> jnp.squeeze(x)
Array([0, 1, 2], dtype=int32)
>>> _.shape
(3,)

显式指定轴时的等效操作

>>> jnp.squeeze(x, axis=(1, 2))
Array([0, 1, 2], dtype=int32)

尝试压缩非单位轴将导致错误

>>> jnp.squeeze(x, axis=0)  
Traceback (most recent call last):
  ...
ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,)

为了方便起见,此功能也可以通过 jax.Array.squeeze() 方法实现

>>> x.squeeze()
Array([0, 1, 2], dtype=int32)