jax.numpy.unstack#

jax.numpy.unstack(x, /, *, axis=0)[源代码]#

沿轴拆分数组。

array_api.unstack() 的 JAX 实现。

参数:
  • x (ArrayLike) – 要解堆叠的数组。必须满足 x.ndim >= 1

  • axis (int) – 解堆叠的整数轴。必须满足 -x.ndim <= axis < x.ndim

返回值:

解堆叠的数组的元组。

返回类型:

tuple[Array, …]

另请参阅

示例

>>> arr = jnp.array([[1, 2, 3],
...                  [4, 5, 6]])
>>> arrs = jnp.unstack(arr)
>>> print(*arrs)
[1 2 3] [4 5 6]

stack() 提供此操作的逆操作

>>> jnp.stack(arrs)
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)