jax.numpy.repeat

内容

jax.numpy.repeat#

jax.numpy.repeat(a, repeats, axis=None, *, total_repeat_length=None)[source]#

从重复元素构造数组。

JAX 实现 numpy.repeat()

参数:
  • a (ArrayLike) – N 维数组

  • repeats (ArrayLike) – 指定重复次数的 1D 整数数组。必须与重复轴的长度匹配。

  • axis (int | None | None) – 指定 a 沿其构造重复数组的轴的整数。如果为 None(默认值),则 a 首先被展平。

  • total_repeat_length (int | None | None) – 为了使 jnp.repeatjit() 和其他 JAX 变换兼容,此参数必须静态指定。如果 sum(repeats) 大于指定的 total_repeat_length,则剩余的值将被丢弃。如果 sum(repeats) 小于 total_repeat_length,则将重复最后一个值。

返回:

a 的重复值构成的数组。

返回类型:

数组

另请参阅

示例

沿最后一个轴重复每个值两次

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.repeat(a, 2, axis=-1)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)

如果未指定 axis,则输入数组将被展平

>>> jnp.repeat(a, 2)
Array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)

将数组传递给 repeats 以不同次数重复每个值

>>> repeats = jnp.array([2, 3])
>>> jnp.repeat(a, repeats, axis=1)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)

为了在 jit 和其他 JAX 变换中使用 repeat,必须使用 total_repeat_length 静态指定输出的大小

>>> jit_repeat = jax.jit(jnp.repeat, static_argnames=['axis', 'total_repeat_length'])
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=5)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)

如果 total_repeat_length 小于 sum(repeats),则结果将被截断

>>> jit_repeat(a, repeats, axis=1, total_repeat_length=4)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)

如果它更大,则其他条目将填充最后一个值

>>> jit_repeat(a, repeats, axis=1, total_repeat_length=7)
Array([[1, 1, 2, 2, 2, 2, 2],
       [3, 3, 4, 4, 4, 4, 4]], dtype=int32)