jax.numpy.vdot#

jax.numpy.vdot(a, b, *, precision=None, preferred_element_type=None)[源代码]#

执行两个一维向量的共轭乘法。

numpy.vdot() 的 JAX 实现。

参数:
  • a (ArrayLike) – 第一个输入数组,如果不是 1D,则会被展平。

  • b (类数组) – 第二个输入数组,如果不是一维数组,则会被展平。必须满足 a.size == b.size

  • precision (精度类) – 可以是 None (默认),表示使用后端的默认精度;也可以是一个 Precision 枚举值 ( Precision.DEFAULTPrecision.HIGHPrecision.HIGHEST );或者是一个包含两个这种值的元组,分别表示 ab 的精度。

  • preferred_element_type (DTypeLike | None) – 可以是 None (默认),表示使用输入类型的默认累加类型;也可以是一个数据类型,表示将结果累加到该数据类型并返回该数据类型的结果。

返回值:

包含输入共轭向量积的标量数组(形状为 ())。

返回类型:

数组

另请参阅

示例

>>> x = jnp.array([1j, 2j, 3j])
>>> y = jnp.array([1., 2., 3.])
>>> jnp.vdot(x, y)
Array(0.-14.j, dtype=complex64)

注意此函数与 dot() 的区别,当输入为复数时,dot() 不会对第一个输入进行共轭操作

>>> jnp.dot(x, y)
Array(0.+14.j, dtype=complex64)