jax.numpy.polyint#
- jax.numpy.polyint(p, m=1, k=None)[源代码]#
返回指定阶数多项式积分的系数。
numpy.polyint()
的 JAX 实现。- 参数:
p (类数组) – 多项式系数数组。
m (int) – 积分阶数。默认值为 1。它必须静态指定。
k (int | ArrayLike | None) – 标量或
m
个积分常数的数组。
- 返回值:
积分多项式的系数数组。
- 返回类型:
另请参阅
jax.numpy.polyder()
:计算多项式导数的系数。jax.numpy.polyval()
:计算多项式在特定值上的值。
示例
多项式 \(12 x^2 + 12 x + 6\) 的一阶积分是 \(4 x^3 + 6 x^2 + 6 x\)。
>>> p = jnp.array([12, 12, 6]) >>> jnp.polyint(p) Array([4., 6., 6., 0.], dtype=float32)
由于没有提供常数
k
,结果在末尾包含了0
。如果提供了常数k
>>> jnp.polyint(p, k=4) Array([4., 6., 6., 4.], dtype=float32)
则二阶积分是 \(x^4 + 2 x^3 + 3 x\)
>>> jnp.polyint(p, m=2) Array([1., 2., 3., 0., 0.], dtype=float32)
当
m>=2
时,常数k
应以包含m
个元素的数组形式提供。多项式 \(12 x^2 + 12 x + 6\) 的二阶积分,常数为k=[4, 5]
,结果为 \(x^4 + 2 x^3 + 3 x^2 + 4 x + 5\)>>> jnp.polyint(p, m=2, k=jnp.array([4, 5])) Array([1., 2., 3., 4., 5.], dtype=float32)