jax.lax.dot_general

目录

jax.lax.dot_general#

jax.lax.dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None)[source]#

一般的点积/收缩运算符。

包装 XLA 的 DotGeneral 运算符。

dot_general 的语义很复杂,但大多数用户不需要直接使用它。 相反,您可以使用更高级的函数,如 jax.numpy.dot()jax.numpy.matmul()jax.numpy.tensordot()jax.numpy.einsum() 以及其他函数,它们将在幕后构建对 dot_general 的适当调用。 如果您真的想了解 dot_general 本身,我们建议您阅读 XLA 的 DotGeneral 运算符文档。

参数:
  • lhs (ArrayLike) – 数组

  • rhs (ArrayLike) – 数组

  • dimension_numbers (DotDimensionNumbers) – 一个元组,包含形式为 ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)) 的整数序列元组

  • precision (PrecisionLike | None) – 可选。要么为 None,表示使用后端的默认精度,要么为 Precision 枚举值(Precision.DEFAULTPrecision.HIGHPrecision.HIGHEST),或者为两个 Precision 枚举值的元组,分别表示 lhs`rhs 的精度。

  • preferred_element_type (DTypeLike | None | None) – 可选。要么为 None,表示使用输入类型的默认累加类型,要么为数据类型,表示将结果累加到该数据类型并返回具有该数据类型的结果。

返回值:

一个数组,其第一个维度是(共享的)批次维度,之后是 lhs 的非收缩/非批次维度,最后是 rhs 的非收缩/非批次维度。

返回类型:

数组