jax.scipy.fft.idct

内容

jax.scipy.fft.idct#

jax.scipy.fft.idct(x, type=2, n=None, axis=-1, norm=None)[source]#

计算输入的逆离散余弦变换

JAX 实现 scipy.fft.idct().

参数:
  • x (Array) – 数组

  • type (int) – 整数,默认值 = 2。当前仅支持类型 2。

  • n (int | None | None) – 整数,默认值 = x.shape[axis]。变换的长度。如果大于 x.shape[axis],输入将被零填充,如果小于,输入将被截断。

  • axis (int) – 整数,默认值 = -1。执行 dct 的轴。

  • norm (str | None | None) – 字符串。归一化模式:[None, "backward", "ortho"]之一。默认值为None,等效于"backward"

返回值:

包含 x 的反离散余弦变换的数组

返回类型:

数组

参见

示例

>>> x = jax.random.normal(jax.random.key(0), (3, 3))
>>> with jnp.printoptions(precision=2, suppress=True):
...    print(jax.scipy.fft.idct(x))
[[-0.02 -0.   -0.17]
 [-0.02 -0.07 -0.28]
 [-0.16 -0.36  0.18]]

n小于x.shape[axis]

>>> with jnp.printoptions(precision=2, suppress=True):
...    print(jax.scipy.fft.idct(x, n=2))
[[ 0.   -0.19]
 [-0.03 -0.34]
 [-0.38  0.04]]

n小于x.shape[axis]axis=0

>>> with jnp.printoptions(precision=2, suppress=True):
...    print(jax.scipy.fft.idct(x, n=2, axis=0))
[[-0.35  0.23 -0.1 ]
 [ 0.17 -0.09  0.01]]

n大于x.shape[axis]axis=0

>>> with jnp.printoptions(precision=2, suppress=True):
...    print(jax.scipy.fft.idct(x, n=4, axis=0))
[[-0.34  0.03  0.07]
 [ 0.    0.18 -0.17]
 [ 0.14  0.09 -0.14]
 [ 0.   -0.18  0.14]]

jax.scipy.fft.idct 可用于从 jax.scipy.fft.dct 的结果中重建 x

>>> x_dct = jax.scipy.fft.dct(x)
>>> jnp.allclose(x, jax.scipy.fft.idct(x_dct))
Array(True, dtype=bool)