jax.numpy.repeat#
- jax.numpy.repeat(a, repeats, axis=None, *, total_repeat_length=None)[源代码]#
从重复的元素构造数组。
numpy.repeat()
的 JAX 实现。- 参数:
a (类数组) – N 维数组
repeats (类数组) – 1D 整数数组,指定重复次数。必须与重复轴的长度匹配。
axis (int | None | None) – 整数,指定沿
a
构造重复数组的轴。如果为 None(默认值),则首先将a
展平。total_repeat_length (int | None | None) – 为了使
jnp.repeat
与jit()
和其他 JAX 转换兼容,必须静态指定此值。如果sum(repeats)
大于指定的total_repeat_length
,则剩余的值将被丢弃。如果sum(repeats)
小于total_repeat_length
,则最后一个值将被重复。
- 返回:
一个由
a
的重复值构造的数组。- 返回类型:
另请参阅
jax.numpy.tile()
: 重复整个数组而不是单个值。
示例
沿最后一个轴重复每个值两次
>>> a = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.repeat(a, 2, axis=-1) Array([[1, 1, 2, 2], [3, 3, 4, 4]], dtype=int32)
如果未指定
axis
,则输入数组将被展平>>> jnp.repeat(a, 2) Array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)
传递一个数组给
repeats
,以便每个值重复不同的次数>>> repeats = jnp.array([2, 3]) >>> jnp.repeat(a, repeats, axis=1) Array([[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]], dtype=int32)
为了在
jit
和其他 JAX 转换中使用repeat
,必须使用total_repeat_length
静态指定输出的大小>>> jit_repeat = jax.jit(jnp.repeat, static_argnames=['axis', 'total_repeat_length']) >>> jit_repeat(a, repeats, axis=1, total_repeat_length=5) Array([[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]], dtype=int32)
如果 total_repeat_length 小于
sum(repeats)
,则结果将被截断>>> jit_repeat(a, repeats, axis=1, total_repeat_length=4) Array([[1, 1, 2, 2], [3, 3, 4, 4]], dtype=int32)
如果它更大,则额外的条目将用最后一个值填充
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=7) Array([[1, 1, 2, 2, 2, 2, 2], [3, 3, 4, 4, 4, 4, 4]], dtype=int32)