jax.experimental.sparse.bcoo_dot_general_sampled

jax.experimental.sparse.bcoo_dot_general_sampled#

jax.experimental.sparse.bcoo_dot_general_sampled(A, B, indices, *, dimension_numbers)[source]#

在给定稀疏索引处计算输出的收缩运算。

参数:
  • lhs – NumPy 数组。

  • rhs – NumPy 数组。

  • indices (Array) – BCOO 索引。

  • dimension_numbers (DotDimensionNumbers) – 形如 ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)) 的元组。

  • A (Array)

  • B (Array)

返回值:

BCOO 数据,一个包含结果的 NumPy 数组。

返回类型:

Array