jax.numpy.tile#
- jax.numpy.tile(A, reps)[源代码]#
通过沿指定维度重复
A
来构造数组。numpy.tile()
的 JAX 实现。如果
A
是形状为(d1, d2, ..., dn)
的数组,并且reps
是一个整数序列,则生成的数组的形状为(reps[0] * d1, reps[1] * d2, ..., reps[n] * dn)
,其中A
沿每个维度平铺。- 参数:
A (ArrayLike) – 要重复的输入数组。可以是任何形状或维度。
reps (DimSize | Sequence[DimSize]) – 指定沿每个轴的重复次数。
- 返回:
一个新的数组,其中输入数组已根据
reps
进行重复。- 返回类型:
另请参阅
jax.numpy.repeat()
:从重复的元素构造一个数组。jax.numpy.broadcast_to()
:将数组广播到指定的形状。
示例
>>> arr = jnp.array([1, 2]) >>> jnp.tile(arr, 2) Array([1, 2, 1, 2], dtype=int32) >>> arr = jnp.array([[1, 2], ... [3, 4,]]) >>> jnp.tile(arr, (2, 1)) Array([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=int32)