jax.numpy.squeeze#
- jax.numpy.squeeze(a, axis=None)[source]#
从数组中移除一个或多个长度为 1 的轴
JAX 实现
numpy.sqeeze()
,通过jax.lax.squeeze()
实现。- 参数:
- 返回值:
移除长度为 1 的轴后的
a
的副本。- 返回类型:
注释
与
numpy.squeeze()
不同,jax.numpy.squeeze()
将返回输入数组的副本,而不是视图。但是,在 JIT 下,编译器将在可能的情况下优化掉此类副本,因此在实践中不会影响性能。另请参阅
jax.numpy.expand_dims()
:squeeze
的反函数:添加长度为 1 的维度。jax.Array.squeeze()
: 通过数组方法实现等效功能。jax.lax.squeeze()
: 等效的 XLA API。jax.numpy.ravel()
: 将数组展平成一维形状。jax.numpy.reshape()
: 通用数组重塑。
示例
>>> x = jnp.array([[[0]], [[1]], [[2]]]) >>> x.shape (3, 1, 1)
压缩所有长度为 1 的维度
>>> jnp.squeeze(x) Array([0, 1, 2], dtype=int32) >>> _.shape (3,)
显式指定轴时的等效操作
>>> jnp.squeeze(x, axis=(1, 2)) Array([0, 1, 2], dtype=int32)
尝试压缩非单位轴将导致错误
>>> jnp.squeeze(x, axis=0) Traceback (most recent call last): ... ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,)
为了方便起见,此功能也可以通过
jax.Array.squeeze()
方法实现>>> x.squeeze() Array([0, 1, 2], dtype=int32)