jax.lax.dot_general#

jax.lax.dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None, out_type=None)[源代码]#

通用的点积/缩并运算符。

封装了 XLA 的 DotGeneral 运算符。

dot_general 的语义很复杂,但大多数用户不应该直接使用它。相反,您可以使用更高级的函数,如 jax.numpy.dot()jax.numpy.matmul()jax.numpy.tensordot()jax.numpy.einsum() 等,这些函数将在底层构建对 dot_general 的适当调用。如果您真的想了解 dot_general 本身,我们建议您阅读 XLA 的 DotGeneral 运算符文档。

参数:
  • lhs (ArrayLike) – 一个数组

  • rhs (ArrayLike) – 一个数组

  • dimension_numbers (DotDimensionNumbers) – 一个由整数序列组成的元组的元组,形式为 ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))

  • precision (PrecisionLike | None) –

    可选。此参数控制计算的数值精度,它可以是以下之一

    • None,表示当前后端的默认精度,

    • 一个 Precision 枚举值,或一个由两个 Precision 枚举值组成的元组,分别表示 lhsrhs 的精度,或者

    • 一个 DotAlgorithm 或一个 DotAlgorithmPreset,表示必须用于累加点积的算法。

  • preferred_element_type (DTypeLike | None | None) – 可选。此参数控制点积输出的数据类型。默认情况下,此操作的输出元素类型将根据通常的类型提升规则与 lhsrhs 输入元素类型匹配。将 preferred_element_type 设置为特定的 dtype 将意味着操作返回该元素类型。当 precision 不是 DotAlgorithmDotAlgorithmPreset 时,preferred_element_type 为编译器提供一个提示,以使用此数据类型累加点积。

返回:

一个数组,其第一个维度是(共享的)批处理维度,后跟 lhs 的非缩并/非批处理维度,最后是 rhs 的非缩并/非批处理维度。

返回类型:

数组