jax.numpy.polydiv

内容

jax.numpy.polydiv#

jax.numpy.polydiv(u, v, *, trim_leading_zeros=False)[source]#

返回多项式除法的商和余数。

JAX 实现 numpy.polydiv().

参数:
  • u (ArrayLike) – 被除数多项式系数的数组。

  • v (ArrayLike) – 除数多项式系数的数组。

  • trim_leading_zeros (bool) – 默认值为 False。如果为 True,则删除返回值中的前导零以匹配 numpy 的结果。但是,这会阻止该函数在编译代码中使用。由于浮点运算误差累积的差异,被视为零的值的截止值可能会导致 NumPy 和 JAX 之间,甚至不同 JAX 后端之间的结果不一致。当 trim_leading_zeros=True 时,结果可能会导致输出形状不一致。

返回值:

商和余数数组的元组。输出的 dtype 始终被提升为非精确。

返回类型:

元组[数组, 数组]

注意

jax.numpy.polydiv() 仅接受数组作为输入,与 numpy.polydiv() 不同,后者也接受标量输入。

另请参阅

示例

>>> x1 = jnp.array([5, 7, 9])
>>> x2 = jnp.array([4, 1])
>>> np.polydiv(x1, x2)
(array([1.25  , 1.4375]), array([7.5625]))
>>> jnp.polydiv(x1, x2)
(Array([1.25  , 1.4375], dtype=float32), Array([0.    , 0.    , 7.5625], dtype=float32))

如果 trim_leading_zeros=True,结果将与 np.polydiv 的结果匹配。

>>> jnp.polydiv(x1, x2, trim_leading_zeros=True)
(Array([1.25  , 1.4375], dtype=float32), Array([7.5625], dtype=float32))