jax.numpy.convolve#

jax.numpy.convolve(a, v, mode='full', *, precision=None, preferred_element_type=None)[源代码]#

两个一维数组的卷积。

JAX 实现的 numpy.convolve()

一维数组的卷积定义为

\[c_k = \sum_j a_{k - j} v_j\]
参数:
  • a (ArrayLike) – 卷积的左侧输入。必须有 a.ndim == 1

  • v (ArrayLike) – 卷积的右侧输入。必须有 v.ndim == 1

  • mode (str) –

    控制输出的大小。可用的操作有

    • "full": (默认) 输出输入的完整卷积。

    • "same": 返回 "full" 输出的中心部分,该部分的大小与 a 相同。

    • "valid": 返回 "full" 输出中不依赖于数组边缘填充的部分。

  • precision (PrecisionLike) – 指定计算的精度。有关可用值的描述,请参阅 jax.lax.Precision

  • preferred_element_type (DTypeLike | None) – 一种数据类型,指示将结果累积到该数据类型并返回该数据类型的结果。默认为 None,这意味着输入类型的默认累积类型。

返回值:

包含卷积结果的数组。

返回类型:

数组

另请参阅

示例

一些一维卷积示例

>>> x = jnp.array([1, 2, 3, 2, 1])
>>> y = jnp.array([4, 1, 2])

jax.numpy.convolve 默认情况下使用边缘的隐式零填充返回完整卷积

>>> jnp.convolve(x, y)
Array([ 4.,  9., 16., 15., 12.,  5.,  2.], dtype=float32)

指定 mode = 'same' 返回与第一个输入大小相同的中心卷积

>>> jnp.convolve(x, y, mode='same')
Array([ 9., 16., 15., 12.,  5.], dtype=float32)

指定 mode = 'valid' 仅返回两个数组完全重叠的部分

>>> jnp.convolve(x, y, mode='valid')
Array([16., 15., 12.], dtype=float32)

对于复数值输入

>>> x1 = jnp.array([3+1j, 2, 4-3j])
>>> y1 = jnp.array([1, 2-3j, 4+5j])
>>> jnp.convolve(x1, y1)
Array([ 3. +1.j, 11. -7.j, 15.+10.j,  7. -8.j, 31. +8.j], dtype=complex64)