jax.numpy.polydiv#
- jax.numpy.polydiv(u, v, *, trim_leading_zeros=False)[source]#
返回多项式除法的商和余数。
JAX 实现
numpy.polydiv()
.- 参数:
u (ArrayLike) – 被除数多项式系数的数组。
v (ArrayLike) – 除数多项式系数的数组。
trim_leading_zeros (bool) – 默认值为
False
。如果为True
,则删除返回值中的前导零以匹配 numpy 的结果。但是,这会阻止该函数在编译代码中使用。由于浮点运算误差累积的差异,被视为零的值的截止值可能会导致 NumPy 和 JAX 之间,甚至不同 JAX 后端之间的结果不一致。当trim_leading_zeros=True
时,结果可能会导致输出形状不一致。
- 返回值:
商和余数数组的元组。输出的 dtype 始终被提升为非精确。
- 返回类型:
注意
jax.numpy.polydiv()
仅接受数组作为输入,与numpy.polydiv()
不同,后者也接受标量输入。另请参阅
jax.numpy.polyadd()
: 计算两个多项式的和。jax.numpy.polysub()
: 计算两个多项式的差。jax.numpy.polymul()
: 计算两个多项式的积。
示例
>>> x1 = jnp.array([5, 7, 9]) >>> x2 = jnp.array([4, 1]) >>> np.polydiv(x1, x2) (array([1.25 , 1.4375]), array([7.5625])) >>> jnp.polydiv(x1, x2) (Array([1.25 , 1.4375], dtype=float32), Array([0. , 0. , 7.5625], dtype=float32))
如果
trim_leading_zeros=True
,结果将与np.polydiv
的结果匹配。>>> jnp.polydiv(x1, x2, trim_leading_zeros=True) (Array([1.25 , 1.4375], dtype=float32), Array([7.5625], dtype=float32))