jax.numpy.stack

内容

jax.numpy.stack#

jax.numpy.stack(arrays, axis=0, out=None, dtype=None)[source]#

沿着新轴连接数组。

JAX 实现 numpy.stack().

参数:
  • arrays (np.ndarray | Array | Sequence[ArrayLike]) – 要堆叠的数组序列;每个数组都必须具有相同的形状。如果给定单个数组,它将等效于 arrays = unstack(arrays),但实现将避免显式拆分。

  • axis (int) – 指定要堆叠的轴。

  • out (None | None) – JAX 未使用

  • dtype (DTypeLike | None | None) – 结果数组的可选数据类型。如果未指定,则数据类型将通过在 类型提升语义 中描述的类型提升规则确定。

返回值:

堆叠的结果。

返回类型:

Array

另请参阅

示例

>>> x = jnp.array([1, 2, 3])
>>> y = jnp.array([4, 5, 6])
>>> jnp.stack([x, y])
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.stack([x, y], axis=1)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)

unstack() 执行逆操作

>>> arr = jnp.stack([x, y], axis=1)
>>> x, y = jnp.unstack(arr, axis=1)
>>> x
Array([1, 2, 3], dtype=int32)
>>> y
Array([4, 5, 6], dtype=int32)