jax.numpy.concat#
- jax.numpy.concat(arrays, /, *, axis=0)[源代码]#
沿现有轴连接数组。
array_api.concat()
的 JAX 实现。- 参数:
arrays (Sequence[ArrayLike]) – 要连接的数组序列;每个数组必须具有相同的形状,但指定的轴除外。 如果给定了单个数组,则它将被视为等同于 arrays = unstack(arrays),但实现将避免显式解堆栈。
axis (int | None) – 指定沿其连接的轴。
- 返回:
连接的结果。
- 返回类型:
另请参阅
jax.lax.concatenate()
:XLA 连接 API。jax.numpy.concatenate()
:此函数的 NumPy 版本。jax.numpy.stack()
:沿新轴连接数组。
示例
一维连接
>>> x = jnp.arange(3) >>> y = jnp.zeros(3, dtype=int) >>> jnp.concat([x, y]) Array([0, 1, 2, 0, 0, 0], dtype=int32)
二维连接
>>> x = jnp.ones((2, 3)) >>> y = jnp.zeros((2, 1)) >>> jnp.concat([x, y], axis=1) Array([[1., 1., 1., 0.], [1., 1., 1., 0.]], dtype=float32)