jax.lax.batch_matmul#

jax.lax.batch_matmul(lhs, rhs, precision=None)[源代码]#

批量矩阵乘法。

参数:
  • lhs (Array)

  • rhs (Array)

  • precision (PrecisionLike | None)

返回类型:

数组