jax.scipy.fft.dct#
- jax.scipy.fft.dct(x, type=2, n=None, axis=-1, norm=None)[source]#
计算输入的离散余弦变换
scipy.fft.dct()
的 JAX 实现。- 参数:
- 返回:
包含 x 的离散余弦变换的数组
- 返回类型:
另请参阅
jax.scipy.fft.dctn()
: 多维 DCTjax.scipy.fft.idct()
: 逆 DCTjax.scipy.fft.idctn()
: 多维逆 DCT
示例
>>> x = jax.random.normal(jax.random.key(0), (3, 3)) >>> with jnp.printoptions(precision=2, suppress=True): ... print(jax.scipy.fft.dct(x)) [[-0.58 -0.33 -1.08] [-0.88 -1.01 -1.79] [-1.06 -2.43 1.24]]
当
n
小于x.shape[axis]
时>>> with jnp.printoptions(precision=2, suppress=True): ... print(jax.scipy.fft.dct(x, n=2)) [[-0.22 -0.9 ] [-0.57 -1.68] [-2.52 -0.11]]
当
n
小于x.shape[axis]
且axis=0
时>>> with jnp.printoptions(precision=2, suppress=True): ... print(jax.scipy.fft.dct(x, n=2, axis=0)) [[-2.22 1.43 -0.67] [ 0.52 -0.26 -0.04]]
当
n
大于x.shape[axis]
且axis=1
时>>> with jnp.printoptions(precision=2, suppress=True): ... print(jax.scipy.fft.dct(x, n=4, axis=1)) [[-0.58 -0.35 -0.64 -1.11] [-0.88 -0.9 -1.46 -1.68] [-1.06 -2.25 -1.15 1.93]]