jax.numpy.sinc#

jax.numpy.sinc(x, /)[源代码]#

计算归一化的 sinc 函数。

JAX 中 numpy.sinc() 的实现。

归一化的 sinc 函数由下式给出:

\[\mathrm{sinc}(x) = \frac{\sin({\pi x})}{\pi x}\]

其中 sinc(0) 返回极限值 1。sinc 函数是光滑且无限可微的。

参数:

x (类数组) – 输入数组;将会被提升为非精确类型。

返回:

一个与 x 形状相同的数组,包含计算结果。

返回类型:

数组

示例

>>> x = jnp.array([-1, -0.5, 0, 0.5, 1])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.sinc(x)
Array([-0.   ,  0.637,  1.   ,  0.637, -0.   ], dtype=float32)

将其与计算该函数的朴素方法进行比较,该方法在零处未定义

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.sin(jnp.pi * x) / (jnp.pi * x)
Array([-0.   ,  0.637,    nan,  0.637, -0.   ], dtype=float32)

JAX 为 sinc 定义了一个自定义梯度规则,以允许即使对于高阶导数也能准确评估零处的梯度

>>> f = jnp.sinc
>>> for i in range(1, 6):
...   f = jax.grad(f)
...   print(f"(d/dx)^{i} f(0.0) = {f(0.0):.2f}")
...
(d/dx)^1 f(0.0) = 0.00
(d/dx)^2 f(0.0) = -3.29
(d/dx)^3 f(0.0) = 0.00
(d/dx)^4 f(0.0) = 19.48
(d/dx)^5 f(0.0) = 0.00