jax.numpy.squeeze#
- jax.numpy.squeeze(a, axis=None)[源代码]#
从数组中移除一个或多个长度为1的轴。
JAX 实现的
numpy.sqeeze()
,通过jax.lax.squeeze()
实现。- 参数:
- 返回值:
移除长度为1的轴的
a
的副本。- 返回类型:
备注
与
numpy.squeeze()
不同,jax.numpy.squeeze()
将返回输入数组的副本而不是视图。但是,在 JIT 下,编译器会在可能的情况下优化掉这些副本,因此在实践中不会对性能产生影响。另请参阅
jax.numpy.expand_dims()
:squeeze
的逆操作:添加长度为 1 的维度。jax.Array.squeeze()
:通过数组方法实现的等效功能。jax.lax.squeeze()
:等效的 XLA API。jax.numpy.ravel()
:将数组展平为 1D 形状。jax.numpy.reshape()
:通用数组重塑。
示例
>>> x = jnp.array([[[0]], [[1]], [[2]]]) >>> x.shape (3, 1, 1)
压缩所有长度为 1 的维度
>>> jnp.squeeze(x) Array([0, 1, 2], dtype=int32) >>> _.shape (3,)
显式指定轴时的等效操作
>>> jnp.squeeze(x, axis=(1, 2)) Array([0, 1, 2], dtype=int32)
尝试压缩非单位轴会导致错误
>>> jnp.squeeze(x, axis=0) Traceback (most recent call last): ... ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,)
为方便起见,此功能也可以通过
jax.Array.squeeze()
方法使用>>> x.squeeze() Array([0, 1, 2], dtype=int32)