jax.numpy.polymul

内容

jax.numpy.polymul#

jax.numpy.polymul(a1, a2, *, trim_leading_zeros=False)[source]#

返回两个多项式的乘积。

JAX 实现 numpy.polymul()

参数:
  • a1 (ArrayLike) – 多项式系数的 1D 数组。

  • a2 (ArrayLike) – 多项式系数的 1D 数组。

  • trim_leading_zeros (bool) – 默认为 False。如果 True,则删除返回值中的前导零以匹配 numpy 的结果。但会阻止该函数在编译代码中使用。由于浮点运算误差累积的差异,将值视为零的阈值可能导致 NumPy 和 JAX 之间,甚至不同 JAX 后端之间产生不一致的结果。当 trim_leading_zeros=True 时,结果可能导致输出形状不一致。

返回:

两个多项式乘积的系数数组。输出的 dtype 始终提升为不精确类型。

返回类型:

Array

注意

jax.numpy.polymul() 仅接受数组作为输入,这与 numpy.polymul() 不同,后者也接受标量输入。

参见

示例

>>> x1 = np.array([2, 1, 0])
>>> x2 = np.array([0, 5, 0, 3])
>>> np.polymul(x1, x2)
array([10,  5,  6,  3,  0])
>>> jnp.polymul(x1, x2)
Array([ 0., 10.,  5.,  6.,  3.,  0.], dtype=float32)

如果 trim_leading_zeros=True,则结果与 np.polymul 的结果匹配。

>>> jnp.polymul(x1, x2, trim_leading_zeros=True)
Array([10.,  5.,  6.,  3.,  0.], dtype=float32)

对于数据类型为 complex 的输入数组

>>> x3 = np.array([2., 1+2j, 1-2j])
>>> x4 = np.array([0, 5, 0, 3])
>>> np.polymul(x3, x4)
array([10. +0.j,  5.+10.j, 11.-10.j,  3. +6.j,  3. -6.j])
>>> jnp.polymul(x3, x4)
Array([ 0. +0.j, 10. +0.j,  5.+10.j, 11.-10.j,  3. +6.j,  3. -6.j],      dtype=complex64)
>>> jnp.polymul(x3, x4, trim_leading_zeros=True)
Array([10. +0.j,  5.+10.j, 11.-10.j,  3. +6.j,  3. -6.j], dtype=complex64)