jax.numpy.expand_dims#
- jax.numpy.expand_dims(a, axis)[source]#
在数组中插入长度为 1 的维度
JAX 实现的
numpy.expand_dims()
,通过jax.lax.expand_dims()
实现。- 参数:
- 返回值:
带有新增维度的
a
的副本。- 返回类型:
备注
与
numpy.expand_dims()
不同,jax.numpy.expand_dims()
将返回输入数组的副本而不是视图。但是,在 JIT 下,编译器将在可能的情况下优化掉此类副本,因此在实践中不会对性能产生影响。另请参阅
jax.numpy.squeeze()
:此操作的反操作,即移除长度为 1 的维度。jax.lax.expand_dims()
:此功能的 XLA 版本。
示例
>>> x = jnp.array([1, 2, 3]) >>> x.shape (3,)
扩展前导维度
>>> jnp.expand_dims(x, 0) Array([[1, 2, 3]], dtype=int32) >>> _.shape (1, 3)
扩展尾随维度
>>> jnp.expand_dims(x, 1) Array([[1], [2], [3]], dtype=int32) >>> _.shape (3, 1)
扩展多个维度
>>> jnp.expand_dims(x, (0, 1, 3)) Array([[[[1], [2], [3]]]], dtype=int32) >>> _.shape (1, 1, 3, 1)
也可以通过使用
None
索引更简洁地扩展维度>>> x[None] # equivalent to jnp.expand_dims(x, 0) Array([[1, 2, 3]], dtype=int32) >>> x[:, None] # equivalent to jnp.expand_dims(x, 1) Array([[1], [2], [3]], dtype=int32) >>> x[None, None, :, None] # equivalent to jnp.expand_dims(x, (0, 1, 3)) Array([[[[1], [2], [3]]]], dtype=int32)