jax.experimental.sparse.bcsr_dot_general

jax.experimental.sparse.bcsr_dot_general#

jax.experimental.sparse.bcsr_dot_general(lhs, rhs, *, dimension_numbers, precision=None, preferred_element_type=None)[source]#

一般的收缩运算。

参数:
  • lhs (BCSR | Array) – ndarray 或 BCSR 格式的稀疏数组。

  • rhs (Array) – ndarray 或 BCSR 格式的稀疏数组。

  • dimension_numbers (DotDimensionNumbers) – 形如 ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)) 的元组。

  • precision (None | None) – 未使用

  • preferred_element_type (None | None) – 未使用

返回:

包含结果的 ndarray 或 BCSR 格式稀疏数组。如果两个输入都是稀疏的,则结果将是稀疏的,类型为 BCSR。如果任一输入是密集的,则结果将是密集的,类型为 ndarray。

返回类型:

数组