jax.numpy.expand_dims#
- jax.numpy.expand_dims(a, axis)[源代码]#
在数组中插入长度为 1 的维度
numpy.expand_dims()
的 JAX 实现,通过jax.lax.expand_dims()
实现。- 参数:
a (ArrayLike) – 输入数组
axis (int | Sequence[int]) – 指定要添加轴的位置的整数或整数序列。
- 返回:
复制
a
并添加维度。- 返回类型:
笔记
与
numpy.expand_dims()
不同,jax.numpy.expand_dims()
将返回输入数组的副本,而不是视图。但是,在 JIT 下,编译器会在可能的情况下优化掉这些副本,因此在实践中不会对性能产生影响。另请参阅
jax.numpy.squeeze()
: 此操作的逆操作,即移除长度为 1 的维度。jax.lax.expand_dims()
: 此功能的 XLA 版本。
示例
>>> x = jnp.array([1, 2, 3]) >>> x.shape (3,)
扩展前导维度
>>> jnp.expand_dims(x, 0) Array([[1, 2, 3]], dtype=int32) >>> _.shape (1, 3)
扩展尾部维度
>>> jnp.expand_dims(x, 1) Array([[1], [2], [3]], dtype=int32) >>> _.shape (3, 1)
扩展多个维度
>>> jnp.expand_dims(x, (0, 1, 3)) Array([[[[1], [2], [3]]]], dtype=int32) >>> _.shape (1, 1, 3, 1)
也可以使用
None
进行索引,以更简洁地扩展维度>>> x[None] # equivalent to jnp.expand_dims(x, 0) Array([[1, 2, 3]], dtype=int32) >>> x[:, None] # equivalent to jnp.expand_dims(x, 1) Array([[1], [2], [3]], dtype=int32) >>> x[None, None, :, None] # equivalent to jnp.expand_dims(x, (0, 1, 3)) Array([[[[1], [2], [3]]]], dtype=int32)