jax.numpy.fft.fftn

内容

jax.numpy.fft.fftn#

jax.numpy.fft.fftn(a, s=None, axes=None, norm=None)[source]#

沿给定轴计算多维离散傅立叶变换。

JAX 实现 numpy.fft.fftn().

参数:
  • a (ArrayLike) – 输入数组

  • s (Shape | None | None) – 整数序列。指定结果的形状。如果未指定,它将默认为 a 沿指定的 axes 的形状。

  • axes (Sequence[int] | None | None) – 整数序列,默认为 None。指定计算变换的轴。

  • norm (str | None | None) – 字符串。归一化模式。“backward”、“ortho” 和 “forward” 受支持。

返回:

包含 a 的多维离散傅里叶变换的数组。

返回类型:

数组

参见

示例

jnp.fft.fftn 默认情况下沿所有轴计算变换,当 axes 参数为 None 时。

>>> x = jnp.array([[1, 2, 5, 6],
...                [4, 1, 3, 7],
...                [5, 9, 2, 1]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.fft.fftn(x)
Array([[ 46.  +0.j  ,   0.  +2.j  ,  -6.  +0.j  ,   0.  -2.j  ],
       [ -2.  +1.73j,   6.12+6.73j,   0.  -1.73j, -18.12-3.27j],
       [ -2.  -1.73j, -18.12+3.27j,   0.  +1.73j,   6.12-6.73j]],      dtype=complex64)

s=[2] 时,沿 axis -1 的变换维度将为 2,而沿其他轴的维度将与输入相同。

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.numpy.fft.fftn(x, s=[2]))
[[ 3.+0.j -1.+0.j]
 [ 5.+0.j  3.+0.j]
 [14.+0.j -4.+0.j]]

s=[2]axes=[0] 时,沿 axis 0 的变换维度将为 2,而沿其他轴的维度将与输入相同。

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.numpy.fft.fftn(x, s=[2], axes=[0]))
[[ 5.+0.j  3.+0.j  8.+0.j 13.+0.j]
 [-3.+0.j  1.+0.j  2.+0.j -1.+0.j]]

s=[2, 3] 时,变换的形状将为 (2, 3)

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.numpy.fft.fftn(x, s=[2, 3]))
[[16. +0.j   -0.5+4.33j -0.5-4.33j]
 [ 0. +0.j   -4.5+0.87j -4.5-0.87j]]

jnp.fft.ifftn 可用于从 jnp.fft.fftn 的结果中重建 x

>>> x_fftn = jnp.fft.fftn(x)
>>> jnp.allclose(x, jnp.fft.ifftn(x_fftn))
Array(True, dtype=bool)