jax.scipy.fft.dctn

内容

jax.scipy.fft.dctn#

jax.scipy.fft.dctn(x, type=2, s=None, axes=None, norm=None)[source]#

计算输入的多维离散余弦变换

JAX 实现 scipy.fft.dctn().

参数:
  • x (Array) – 数组

  • type (int) – 整数,默认值为 2。目前仅支持类型 2。

  • s (Sequence[int] | None | None) – 整数或整数序列。指定结果的形状。如果未指定,则默认为 x 沿指定 axes 的形状。

  • axes (Sequence[int] | None | None) – 整数或整数序列。指定将沿其计算变换的轴。

  • norm (str | None | None) – 字符串。归一化模式:[None, "backward", "ortho"] 中的一个。默认值为 None,等效于 "backward"

返回:

包含 x 的离散余弦变换的数组

返回类型:

数组

另请参阅

示例

jax.scipy.fft.dctnaxes 参数为 None 时,默认情况下沿两个轴计算变换。

>>> x = jax.random.normal(jax.random.key(0), (3, 3))
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dctn(x))
[[-5.04 -7.54 -3.26]
 [ 0.83  3.64 -4.03]
 [ 0.12 -0.73  3.74]]

s=[2] 时,沿 axis 0 的变换维度将为 2,而沿 axis 1 的维度将与输入相同。

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dctn(x, s=[2]))
[[-2.92 -2.68 -5.74]
 [ 0.42  0.97  1.  ]]

s=[2]axes=[1] 时,沿 axis 1 的变换维度将为 2,而沿 axis 0 的维度将与输入相同。此外,当 axes=[1] 时,变换将仅沿 axis 1 计算。

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dctn(x, s=[2], axes=[1]))
[[-0.22 -0.9 ]
 [-0.57 -1.68]
 [-2.52 -0.11]]

s=[2, 4] 时,变换的形状将为 (2, 4)

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dctn(x, s=[2, 4]))
[[-2.92 -2.49 -4.21 -5.57]
 [ 0.42  0.79  1.16  0.8 ]]