jax.numpy.polyval

内容

jax.numpy.polyval#

jax.numpy.polyval(p, x, *, unroll=16)[source]#

在特定值处计算多项式。

JAX 实现 numpy.polyval().

对于长度为 M 的一维多项式系数 p,该函数返回以下值:

\[p_0 x^{M - 1} + p_1 x^{M - 2} + ... + p_{M - 1}\]
参数:
  • p (ArrayLike) – 形状为 (M,) 的多项式系数数组。

  • x (ArrayLike) – 一个数字或一个数字数组。

  • unroll (int) – 用于控制使用 lax.scan 展开的步数。它必须在静态中指定。

返回值:

x 形状相同的数组。

返回类型:

数组

注意

参数 unroll 是 JAX 特定的。它不会影响正确性,但会对评估高阶多项式的性能产生重大影响。该参数控制 jnp.polyval 实现中 lax.scan 的展开步骤数。考虑设置 unroll=128(甚至更高)以提高加速器上的运行时性能,但会增加编译时间。

另请参阅

示例

>>> p = jnp.array([2, 5, 1])
>>> jnp.polyval(p, 3)
Array(34., dtype=float32)

如果 x 是一个二维数组,polyval 返回与 x 形状相同的二维数组。

>>> x = jnp.array([[2, 1, 5],
...                [3, 4, 7],
...                [1, 3, 5]])
>>> jnp.polyval(p, x)
Array([[ 19.,   8.,  76.],
       [ 34.,  53., 134.],
       [  8.,  34.,  76.]], dtype=float32)