jax.numpy.dstack#
- jax.numpy.dstack(tup, dtype=None)[源代码]#
沿深度方向堆叠数组。
JAX 对
numpy.dstack()
的实现。对于三维或更高维度的数组,这等效于使用
axis=2
的jax.numpy.concatenate()
。- 参数:
- 返回:
堆叠的结果。
- 返回类型:
另请参阅
jax.numpy.stack()
:沿任意轴堆叠jax.numpy.concatenate()
:沿现有轴拼接。jax.numpy.vstack()
:垂直堆叠,即沿轴 0 堆叠。jax.numpy.hstack()
:水平堆叠,即沿轴 1 堆叠。
示例
标量值
>>> jnp.dstack([1, 2, 3]) Array([[[1, 2, 3]]], dtype=int32, weak_type=True)
一维数组
>>> x = jnp.arange(3) >>> y = jnp.ones(3) >>> jnp.dstack([x, y]) Array([[[0., 1.], [1., 1.], [2., 1.]]], dtype=float32)
二维数组
>>> x = x.reshape(1, 3) >>> y = y.reshape(1, 3) >>> jnp.dstack([x, y]) Array([[[0., 1.], [1., 1.], [2., 1.]]], dtype=float32)