jax.numpy.polydiv#

jax.numpy.polydiv(u, v, *, trim_leading_zeros=False)[源代码]#

返回多项式除法的商和余数。

numpy.polydiv() 的 JAX 实现。

参数:
  • u (ArrayLike) – 被除数多项式系数的数组。

  • v (ArrayLike) – 除数多项式系数的数组。

  • trim_leading_zeros (bool) – 默认值为 False。如果为 True,则删除返回值中的前导零,以匹配 numpy 的结果。但这会阻止该函数在编译代码中使用。由于浮点算术误差累积的差异,被视为零的值的截止值可能导致 NumPy 和 JAX 之间,甚至不同 JAX 后端之间的结果不一致。当 trim_leading_zeros=True 时,结果可能会导致输出形状不一致。

返回:

商和余数数组的元组。输出的 dtype 始终会被提升为非精确类型。

返回类型:

tuple[Array, Array]

注意

jax.numpy.polydiv() 仅接受数组作为输入,不像 numpy.polydiv() 也接受标量输入。

另请参阅

示例

>>> x1 = jnp.array([5, 7, 9])
>>> x2 = jnp.array([4, 1])
>>> np.polydiv(x1, x2)
(array([1.25  , 1.4375]), array([7.5625]))
>>> jnp.polydiv(x1, x2)
(Array([1.25  , 1.4375], dtype=float32), Array([0.    , 0.    , 7.5625], dtype=float32))

如果 trim_leading_zeros=True,则结果与 np.polydiv 的结果匹配。

>>> jnp.polydiv(x1, x2, trim_leading_zeros=True)
(Array([1.25  , 1.4375], dtype=float32), Array([7.5625], dtype=float32))