jax.numpy.repeat#
- jax.numpy.repeat(a, repeats, axis=None, *, total_repeat_length=None)[source]#
从重复元素构造数组。
JAX 实现
numpy.repeat()
。- 参数:
a (ArrayLike) – N 维数组
repeats (ArrayLike) – 指定重复次数的 1D 整数数组。必须与重复轴的长度匹配。
axis (int | None | None) – 指定
a
沿其构造重复数组的轴的整数。如果为 None(默认值),则a
首先被展平。total_repeat_length (int | None | None) – 为了使
jnp.repeat
与jit()
和其他 JAX 变换兼容,此参数必须静态指定。如果sum(repeats)
大于指定的total_repeat_length
,则剩余的值将被丢弃。如果sum(repeats)
小于total_repeat_length
,则将重复最后一个值。
- 返回:
由
a
的重复值构成的数组。- 返回类型:
另请参阅
jax.numpy.tile()
: 重复整个数组而不是单个值。
示例
沿最后一个轴重复每个值两次
>>> a = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.repeat(a, 2, axis=-1) Array([[1, 1, 2, 2], [3, 3, 4, 4]], dtype=int32)
如果未指定
axis
,则输入数组将被展平>>> jnp.repeat(a, 2) Array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)
将数组传递给
repeats
以不同次数重复每个值>>> repeats = jnp.array([2, 3]) >>> jnp.repeat(a, repeats, axis=1) Array([[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]], dtype=int32)
为了在
jit
和其他 JAX 变换中使用repeat
,必须使用total_repeat_length
静态指定输出的大小>>> jit_repeat = jax.jit(jnp.repeat, static_argnames=['axis', 'total_repeat_length']) >>> jit_repeat(a, repeats, axis=1, total_repeat_length=5) Array([[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]], dtype=int32)
如果 total_repeat_length 小于
sum(repeats)
,则结果将被截断>>> jit_repeat(a, repeats, axis=1, total_repeat_length=4) Array([[1, 1, 2, 2], [3, 3, 4, 4]], dtype=int32)
如果它更大,则其他条目将填充最后一个值
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=7) Array([[1, 1, 2, 2, 2, 2, 2], [3, 3, 4, 4, 4, 4, 4]], dtype=int32)