jax.numpy.fft.fft2

内容

jax.numpy.fft.fft2#

jax.numpy.fft.fft2(a, s=None, axes=(-2, -1), norm=None)[source]#

在给定轴上计算二维离散傅立叶变换。

JAX 实现 numpy.fft.fft2().

参数:
  • a (ArrayLike) – 输入数组。必须有 a.ndim >= 2.

  • s (形状 | | ) – 可选的长度为 2 的整数序列。指定沿每个指定轴的输出大小。如果未指定,则默认情况下为沿指定 axesa 的大小。

  • axes (序列[整数]) – 可选的长度为 2 的整数序列,默认值为 (-2, -1)。指定计算变换的轴。

  • norm (字符串 | | ) – 字符串,默认值为“backward”。归一化模式。“backward”、“ortho”和“forward”受支持。

返回:

包含沿给定 axesa 的二维离散傅里叶变换的数组。

返回类型:

数组

另请参阅

示例

jnp.fft.fft2 默认情况下沿最后两个轴计算变换。

>>> x = jnp.array([[[1, 3],
...                 [2, 4]],
...                [[5, 7],
...                 [6, 8]]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.fft.fft2(x)
Array([[[10.+0.j, -4.+0.j],
        [-2.+0.j,  0.+0.j]],

       [[26.+0.j, -4.+0.j],
        [-2.+0.j,  0.+0.j]]], dtype=complex64)

s=[2, 3] 时,沿 axes (-2, -1) 的变换维度将为 (2, 3),而沿其他轴的维度将与输入相同。

>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.fft.fft2(x, s=[2, 3])
Array([[[10.  +0.j  , -0.5 -6.06j, -0.5 +6.06j],
        [-2.  +0.j  , -0.5 +0.87j, -0.5 -0.87j]],

       [[26.  +0.j  ,  3.5-12.99j,  3.5+12.99j],
        [-2.  +0.j  , -0.5 +0.87j, -0.5 -0.87j]]], dtype=complex64)

s=[2, 3]axes=(0, 1) 时,沿 axes (0, 1) 的变换形状将为 (2, 3),而沿其他轴的维度将与输入相同。

>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.fft.fft2(x, s=[2, 3], axes=(0, 1))
Array([[[14. +0.j  , 22. +0.j  ],
        [ 2. -6.93j,  4.-10.39j],
        [ 2. +6.93j,  4.+10.39j]],

       [[-8. +0.j  , -8. +0.j  ],
        [-2. +3.46j, -2. +3.46j],
        [-2. -3.46j, -2. -3.46j]]], dtype=complex64)

jnp.fft.ifft2 可用于从 jnp.fft.fft2 的结果中重建 x

>>> x_fft2 = jnp.fft.fft2(x)
>>> jnp.allclose(x, jnp.fft.ifft2(x_fft2))
Array(True, dtype=bool)