jax.lax.dot_general#
- jax.lax.dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None)[source]#
一般的点积/收缩运算符。
包装 XLA 的 DotGeneral 运算符。
dot_general
的语义很复杂,但大多数用户不需要直接使用它。 相反,您可以使用更高级的函数,如jax.numpy.dot()
,jax.numpy.matmul()
,jax.numpy.tensordot()
,jax.numpy.einsum()
以及其他函数,它们将在幕后构建对dot_general
的适当调用。 如果您真的想了解dot_general
本身,我们建议您阅读 XLA 的 DotGeneral 运算符文档。- 参数:
lhs (ArrayLike) – 数组
rhs (ArrayLike) – 数组
dimension_numbers (DotDimensionNumbers) – 一个元组,包含形式为
((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))
的整数序列元组precision (PrecisionLike | None) – 可选。要么为
None
,表示使用后端的默认精度,要么为Precision
枚举值(Precision.DEFAULT
、Precision.HIGH
或Precision.HIGHEST
),或者为两个Precision
枚举值的元组,分别表示lhs`
和rhs
的精度。preferred_element_type (DTypeLike | None | None) – 可选。要么为
None
,表示使用输入类型的默认累加类型,要么为数据类型,表示将结果累加到该数据类型并返回具有该数据类型的结果。
- 返回值:
一个数组,其第一个维度是(共享的)批次维度,之后是
lhs
的非收缩/非批次维度,最后是rhs
的非收缩/非批次维度。- 返回类型: