jax.lax.dot#

jax.lax.dot(lhs, rhs, precision=None, preferred_element_type=None)[源代码]#

向量/向量,矩阵/向量和矩阵/矩阵乘法。

包装 XLA 的 Dot 运算符。

有关更一般的收缩,请参阅 jax.lax.dot_general() 运算符。

参数:
  • lhs (Array) – 维度为 1 或 2 的数组。

  • rhs (Array) – 维度为 1 或 2 的数组。

  • precision (PrecisionLike | None) –

    可选。此参数控制计算的数值,可以是以下之一

    • None,表示当前后端的默认精度,

    • 一个 Precision 枚举值,或者一个包含两个 Precision 枚举值的元组,分别表示 lhs`rhs 的精度,或者

    • 一个 DotAlgorithm 或一个 DotAlgorithmPreset,指示必须用于累积点积的算法。

  • preferred_element_type (DTypeLike | None | None) – 可选。此参数控制点积输出的数据类型。默认情况下,此操作的输出元素类型将根据通常的类型提升规则与 lhsrhs 输入元素类型匹配。将 preferred_element_type 设置为特定的 dtype 将意味着该操作返回该元素类型。当 precision 不是 DotAlgorithmDotAlgorithmPreset 时,preferred_element_type 为编译器提供一个提示,指示使用此数据类型累积点积。

返回:

包含乘积的数组。

返回类型:

数组