jax.numpy.concatenate#
- jax.numpy.concatenate(arrays, axis=0, dtype=None)[源代码]#
沿现有轴连接数组。
JAX 实现的
numpy.concatenate()
。- 参数:
- 返回值:
拼接后的结果。
- 返回类型:
另请参阅
jax.lax.concatenate()
: XLA 拼接 API。jax.numpy.concat()
: 此函数的 Array API 版本。jax.numpy.stack()
: 沿新轴拼接数组。
示例
一维拼接
>>> x = jnp.arange(3) >>> y = jnp.zeros(3, dtype=int) >>> jnp.concatenate([x, y]) Array([0, 1, 2, 0, 0, 0], dtype=int32)
二维拼接
>>> x = jnp.ones((2, 3)) >>> y = jnp.zeros((2, 1)) >>> jnp.concatenate([x, y], axis=1) Array([[1., 1., 1., 0.], [1., 1., 1., 0.]], dtype=float32)