jax.numpy.repeat#

jax.numpy.repeat(a, repeats, axis=None, *, total_repeat_length=None)[源代码]#

从重复的元素构造数组。

numpy.repeat() 的 JAX 实现。

参数:
  • a (类数组) – N 维数组

  • repeats (类数组) – 1D 整数数组,指定重复次数。必须与重复轴的长度匹配。

  • axis (int | None | None) – 整数,指定沿 a 构造重复数组的轴。如果为 None(默认值),则首先将 a 展平。

  • total_repeat_length (int | None | None) – 为了使 jnp.repeatjit() 和其他 JAX 转换兼容,必须静态指定此值。如果 sum(repeats) 大于指定的 total_repeat_length,则剩余的值将被丢弃。如果 sum(repeats) 小于 total_repeat_length,则最后一个值将被重复。

返回:

一个由 a 的重复值构造的数组。

返回类型:

数组

另请参阅

示例

沿最后一个轴重复每个值两次

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.repeat(a, 2, axis=-1)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)

如果未指定 axis,则输入数组将被展平

>>> jnp.repeat(a, 2)
Array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)

传递一个数组给 repeats,以便每个值重复不同的次数

>>> repeats = jnp.array([2, 3])
>>> jnp.repeat(a, repeats, axis=1)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)

为了在 jit 和其他 JAX 转换中使用 repeat,必须使用 total_repeat_length 静态指定输出的大小

>>> jit_repeat = jax.jit(jnp.repeat, static_argnames=['axis', 'total_repeat_length'])
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=5)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)

如果 total_repeat_length 小于 sum(repeats),则结果将被截断

>>> jit_repeat(a, repeats, axis=1, total_repeat_length=4)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)

如果它更大,则额外的条目将用最后一个值填充

>>> jit_repeat(a, repeats, axis=1, total_repeat_length=7)
Array([[1, 1, 2, 2, 2, 2, 2],
       [3, 3, 4, 4, 4, 4, 4]], dtype=int32)