jax.numpy.stack#
- jax.numpy.stack(arrays, axis=0, out=None, dtype=None)[source]#
沿着新轴连接数组。
JAX 实现
numpy.stack()
.- 参数:
- 返回值:
堆叠的结果。
- 返回类型:
另请参阅
jax.numpy.unstack()
:stack
的逆操作。jax.numpy.concatenate()
:沿现有轴进行连接。jax.numpy.vstack()
:垂直堆叠,即沿轴 0。jax.numpy.hstack()
:水平堆叠,即沿轴 1。jax.numpy.dstack()
:深度堆叠,即沿轴 2。
示例
>>> x = jnp.array([1, 2, 3]) >>> y = jnp.array([4, 5, 6]) >>> jnp.stack([x, y]) Array([[1, 2, 3], [4, 5, 6]], dtype=int32) >>> jnp.stack([x, y], axis=1) Array([[1, 4], [2, 5], [3, 6]], dtype=int32)
unstack()
执行逆操作>>> arr = jnp.stack([x, y], axis=1) >>> x, y = jnp.unstack(arr, axis=1) >>> x Array([1, 2, 3], dtype=int32) >>> y Array([4, 5, 6], dtype=int32)