jax.numpy.convolve

内容

jax.numpy.convolve#

jax.numpy.convolve(a, v, mode='full', *, precision=None, preferred_element_type=None)[source]#

两个一维数组的卷积。

JAX 实现 numpy.convolve()

一维数组的卷积定义为

\[c_k = \sum_j a_{k - j} v_j\]
参数:
  • a (ArrayLike) – 卷积的左侧输入。必须有 a.ndim == 1

  • v (ArrayLike) – 卷积的右侧输入。必须有 v.ndim == 1

  • mode (str) –

    控制输出的大小。可用操作为

    • "full": (默认) 输出输入的完整卷积。

    • "same": 返回 "full" 输出的中心部分,其大小与 a 相同。

    • "valid":返回"full"输出中不依赖于数组边缘填充的部分。

  • precision (PrecisionLike) – 指定计算的精度。有关可用值的说明,请参阅jax.lax.Precision

  • preferred_element_type (DTypeLike | None) – 数据类型,指示将结果累积到并返回具有该数据类型的结果。默认为None,这意味着输入类型的默认累积类型。

返回:

包含卷积结果的数组。

返回类型:

数组

另请参阅

示例

一些一维卷积示例

>>> x = jnp.array([1, 2, 3, 2, 1])
>>> y = jnp.array([4, 1, 2])

jax.numpy.convolve默认情况下返回使用隐式零填充边缘的完整卷积

>>> jnp.convolve(x, y)
Array([ 4.,  9., 16., 15., 12.,  5.,  2.], dtype=float32)

指定mode = 'same'返回与第一个输入大小相同的中心卷积

>>> jnp.convolve(x, y, mode='same')
Array([ 9., 16., 15., 12.,  5.], dtype=float32)

指定mode = 'valid'仅返回两个数组完全重叠的部分

>>> jnp.convolve(x, y, mode='valid')
Array([16., 15., 12.], dtype=float32)

对于复数值输入

>>> x1 = jnp.array([3+1j, 2, 4-3j])
>>> y1 = jnp.array([1, 2-3j, 4+5j])
>>> jnp.convolve(x1, y1)
Array([ 3. +1.j, 11. -7.j, 15.+10.j,  7. -8.j, 31. +8.j], dtype=complex64)