jax.numpy.concatenate#

jax.numpy.concatenate(arrays, axis=0, dtype=None)[源代码]#

沿现有轴连接数组。

JAX 实现的 numpy.concatenate()

参数:
  • arrays (np.ndarray | Array | Sequence[ArrayLike]) – 要连接的数组序列;除了指定的轴之外,每个数组的形状都必须相同。如果给出一个数组,它将被等效地视为 arrays = unstack(arrays),但该实现将避免显式解堆叠。

  • axis (int | None) – 指定沿哪个轴进行拼接。

  • dtype (DTypeLike | None | None) – 可选的生成数组的数据类型。如果未指定,数据类型将通过 类型提升语义 中描述的类型提升规则确定。

返回值:

拼接后的结果。

返回类型:

Array

另请参阅

示例

一维拼接

>>> x = jnp.arange(3)
>>> y = jnp.zeros(3, dtype=int)
>>> jnp.concatenate([x, y])
Array([0, 1, 2, 0, 0, 0], dtype=int32)

二维拼接

>>> x = jnp.ones((2, 3))
>>> y = jnp.zeros((2, 1))
>>> jnp.concatenate([x, y], axis=1)
Array([[1., 1., 1., 0.],
       [1., 1., 1., 0.]], dtype=float32)