jax.numpy.stack#
- jax.numpy.stack(arrays, axis=0, out=None, dtype=None)[源代码]#
沿着新轴连接数组。
numpy.stack()
的 JAX 实现。- 参数:
- 返回:
堆叠的结果。
- 返回类型:
另请参阅
jax.numpy.unstack()
:stack
的反操作。jax.numpy.concatenate()
: 沿现有轴连接。jax.numpy.vstack()
: 垂直堆叠,即沿轴 0。jax.numpy.hstack()
: 水平堆叠,即沿轴 1。jax.numpy.dstack()
: 深度方向堆叠,即沿轴 2。jax.numpy.column_stack()
: 堆叠列。
示例
>>> x = jnp.array([1, 2, 3]) >>> y = jnp.array([4, 5, 6]) >>> jnp.stack([x, y]) Array([[1, 2, 3], [4, 5, 6]], dtype=int32) >>> jnp.stack([x, y], axis=1) Array([[1, 4], [2, 5], [3, 6]], dtype=int32)
unstack()
执行反向操作>>> arr = jnp.stack([x, y], axis=1) >>> x, y = jnp.unstack(arr, axis=1) >>> x Array([1, 2, 3], dtype=int32) >>> y Array([4, 5, 6], dtype=int32)