jax.numpy.vstack#
- jax.numpy.vstack(tup, dtype=None)[source]#
垂直堆叠数组。
JAX 实现
numpy.vstack()
.对于两个或更多维的数组,这等效于
jax.numpy.concatenate()
且axis=0
.- 参数:
- 返回值:
堆叠后的结果。
- 返回类型:
另请参阅
jax.numpy.stack()
: 沿任意轴堆叠jax.numpy.concatenate()
: 沿现有轴连接。jax.numpy.hstack()
: 水平堆叠,即沿轴 1 堆叠。jax.numpy.dstack()
: 深度堆叠,即沿轴 2 堆叠。
示例
标量值
>>> jnp.vstack([1, 2, 3]) Array([[1], [2], [3]], dtype=int32, weak_type=True)
一维数组
>>> x = jnp.arange(4) >>> y = jnp.ones(4) >>> jnp.vstack([x, y]) Array([[0., 1., 2., 3.], [1., 1., 1., 1.]], dtype=float32)
二维数组
>>> x = x.reshape(1, 4) >>> y = y.reshape(1, 4) >>> jnp.vstack([x, y]) Array([[0., 1., 2., 3.], [1., 1., 1., 1.]], dtype=float32)