jax.numpy.polyval#
- jax.numpy.polyval(p, x, *, unroll=16)[源代码]#
在特定值处计算多项式的值。
numpy.polyval()
的 JAX 实现。对于长度为
M
的一维多项式系数p
,该函数返回的值为\[p_0 x^{M - 1} + p_1 x^{M - 2} + ... + p_{M - 1}\]- 参数:
p (ArrayLike) – 形状为
(M,)
的多项式系数数组。x (ArrayLike) – 一个数字或数字数组。
unroll (int) – 用于控制
lax.scan
中展开步数的数字。必须静态指定。
- 返回:
与
x
形状相同的数组。- 返回类型:
注意
unroll
参数是 JAX 特有的。它不影响正确性,但可能对评估高阶多项式的性能产生重大影响。该参数控制jnp.polyval
实现内部lax.scan
的展开步数。考虑设置unroll=128
(甚至更高)以提高加速器上的运行时性能,代价是增加编译时间。另请参阅
jax.numpy.polyfit()
:最小二乘多项式拟合。jax.numpy.poly()
:查找具有给定根的多项式的系数。jax.numpy.roots()
:计算给定系数的多项式的根。
示例
>>> p = jnp.array([2, 5, 1]) >>> jnp.polyval(p, 3) Array(34., dtype=float32)
如果
x
是一个二维数组,则polyval
返回一个与x
形状相同的二维数组>>> x = jnp.array([[2, 1, 5], ... [3, 4, 7], ... [1, 3, 5]]) >>> jnp.polyval(p, x) Array([[ 19., 8., 76.], [ 34., 53., 134.], [ 8., 34., 76.]], dtype=float32)