jax.experimental.sparse.bcoo_dot_general_sampled#

jax.experimental.sparse.bcoo_dot_general_sampled(A, B, indices, *, dimension_numbers)[源代码]#

在给定稀疏索引处计算输出的收缩操作。

参数:
  • lhs – 一个 ndarray。

  • rhs – 一个 ndarray。

  • indices (Array) – BCOO 索引。

  • dimension_numbers (DotDimensionNumbers) – 形式为 ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)) 的元组的元组。

  • A (数组)

  • B (数组)

返回:

BCOO 数据,一个包含结果的 ndarray。

返回类型:

数组