jax.numpy.polyval#

jax.numpy.polyval(p, x, *, unroll=16)[源代码]#

在特定值处计算多项式的值。

numpy.polyval()的 JAX 实现。

对于长度为 M 的一维多项式系数 p,该函数返回的值为

\[p_0 x^{M - 1} + p_1 x^{M - 2} + ... + p_{M - 1}\]
参数:
  • p (ArrayLike) – 形状为 (M,) 的多项式系数数组。

  • x (ArrayLike) – 一个数字或数字数组。

  • unroll (int) – 用于控制 lax.scan 中展开步数的数字。必须静态指定。

返回:

x 形状相同的数组。

返回类型:

Array

注意

unroll 参数是 JAX 特有的。它不影响正确性,但可能对评估高阶多项式的性能产生重大影响。该参数控制 jnp.polyval 实现内部 lax.scan 的展开步数。考虑设置 unroll=128(甚至更高)以提高加速器上的运行时性能,代价是增加编译时间。

另请参阅

示例

>>> p = jnp.array([2, 5, 1])
>>> jnp.polyval(p, 3)
Array(34., dtype=float32)

如果 x 是一个二维数组,则 polyval 返回一个与 x 形状相同的二维数组

>>> x = jnp.array([[2, 1, 5],
...                [3, 4, 7],
...                [1, 3, 5]])
>>> jnp.polyval(p, x)
Array([[ 19.,   8.,  76.],
       [ 34.,  53., 134.],
       [  8.,  34.,  76.]], dtype=float32)