jax.numpy.tile#

jax.numpy.tile(A, reps)[源代码]#

通过沿指定维度重复 A 来构造数组。

numpy.tile() 的 JAX 实现。

如果 A 是形状为 (d1, d2, ..., dn) 的数组,并且 reps 是一个整数序列,则生成的数组的形状为 (reps[0] * d1, reps[1] * d2, ..., reps[n] * dn),其中 A 沿每个维度平铺。

参数:
  • A (ArrayLike) – 要重复的输入数组。可以是任何形状或维度。

  • reps (DimSize | Sequence[DimSize]) – 指定沿每个轴的重复次数。

返回:

一个新的数组,其中输入数组已根据 reps 进行重复。

返回类型:

数组

另请参阅

示例

>>> arr = jnp.array([1, 2])
>>> jnp.tile(arr, 2)
Array([1, 2, 1, 2], dtype=int32)
>>> arr = jnp.array([[1, 2],
...                  [3, 4,]])
>>> jnp.tile(arr, (2, 1))
Array([[1, 2],
       [3, 4],
       [1, 2],
       [3, 4]], dtype=int32)