jax.numpy.unstack#
- jax.numpy.unstack(x, /, *, axis=0)[source]#
沿着某个轴解开数组。
JAX 实现
array_api.unstack()
。- 参数:
x (ArrayLike) – 要解开的数组。必须满足
x.ndim >= 1
。axis (int) – 解开数组的轴。必须满足
-x.ndim <= axis < x.ndim
。
- 返回值:
解开的数组的元组。
- 返回类型:
参见
jax.numpy.stack()
:unstack
的逆操作jax.numpy.split()
:沿指定轴将数组拆分为批次。
示例
>>> arr = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> arrs = jnp.unstack(arr) >>> print(*arrs) [1 2 3] [4 5 6]
stack()
提供了此函数的反操作>>> jnp.stack(arrs) Array([[1, 2, 3], [4, 5, 6]], dtype=int32)