jax.numpy.polyval#
- jax.numpy.polyval(p, x, *, unroll=16)[source]#
在特定值处计算多项式。
JAX 实现
numpy.polyval()
.对于长度为
M
的一维多项式系数p
,该函数返回以下值:\[p_0 x^{M - 1} + p_1 x^{M - 2} + ... + p_{M - 1}\]- 参数:
p (ArrayLike) – 形状为
(M,)
的多项式系数数组。x (ArrayLike) – 一个数字或一个数字数组。
unroll (int) – 用于控制使用
lax.scan
展开的步数。它必须在静态中指定。
- 返回值:
与
x
形状相同的数组。- 返回类型:
注意
参数
unroll
是 JAX 特定的。它不会影响正确性,但会对评估高阶多项式的性能产生重大影响。该参数控制jnp.polyval
实现中lax.scan
的展开步骤数。考虑设置unroll=128
(甚至更高)以提高加速器上的运行时性能,但会增加编译时间。另请参阅
jax.numpy.polyfit()
: 最小二乘多项式拟合。jax.numpy.poly()
: 找到具有给定根的多项式的系数。jax.numpy.roots()
: 计算给定系数的多项式的根。
示例
>>> p = jnp.array([2, 5, 1]) >>> jnp.polyval(p, 3) Array(34., dtype=float32)
如果
x
是一个二维数组,polyval
返回与x
形状相同的二维数组。>>> x = jnp.array([[2, 1, 5], ... [3, 4, 7], ... [1, 3, 5]]) >>> jnp.polyval(p, x) Array([[ 19., 8., 76.], [ 34., 53., 134.], [ 8., 34., 76.]], dtype=float32)