jax.lax.batch_matmul

jax.lax.batch_matmul(lhs, rhs, precision=None)

Batch matrix multiplication.
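The formula below sketches what "batch matrix multiplication" computes. It assumes the behavior of the current JAX source, where the function reduces to `lax.dot_general` with every leading axis treated as a batch axis. Under that assumption both operands have the same rank (at least 2) and identical batch sizes. For `lhs` of shape (B_1, ..., B_n, M, K) and `rhs` of shape (B_1, ..., B_n, K, N), the result has shape (B_1, ..., B_n, M, N), and each entry is an ordinary matrix product taken per batch index:

```latex
\mathrm{out}[b_1,\dots,b_n,i,j] \;=\; \sum_{k=1}^{K} \mathrm{lhs}[b_1,\dots,b_n,i,k]\;\mathrm{rhs}[b_1,\dots,b_n,k,j]
```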

Parameters:
  • lhs (Array)

  • rhs (Array)

  • precision (PrecisionLike | None)

Return type:

Array
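A minimal usage sketch follows. The array shapes and values are arbitrary illustrations. The `lax.dot_general` call spells out the contraction we assume `batch_matmul` performs: the last axis of `lhs` is contracted against the second-to-last axis of `rhs`, and the leading axis is batched. `jnp.matmul` is included only as a familiar cross-check.

```python
import jax.numpy as jnp
from jax import lax

# Two stacks of 4 matrices: lhs holds four 2x3 matrices, rhs four 3x5 matrices.
lhs = jnp.arange(4 * 2 * 3, dtype=jnp.float32).reshape(4, 2, 3)
rhs = jnp.ones((4, 3, 5), dtype=jnp.float32)

# Multiply lhs[b] @ rhs[b] for every batch index b.
out = lax.batch_matmul(lhs, rhs)
print(out.shape)  # (4, 2, 5)

# Assumed-equivalent contraction with lax.dot_general:
# contract lhs axis 2 with rhs axis 1, and treat axis 0 of both as a batch axis.
ref = lax.dot_general(lhs, rhs, dimension_numbers=(((2,), (1,)), ((0,), (0,))))

# jnp.matmul also batches over leading axes, so it should agree here.
ref2 = jnp.matmul(lhs, rhs)

# The precision argument accepts lax.Precision values; HIGHEST requests the
# most accurate algorithm, which mainly matters on accelerators.
out_hi = lax.batch_matmul(lhs, rhs, precision=lax.Precision.HIGHEST)
```

As far as we can tell from the current source, the operands must have equal rank and matching batch sizes, because the batch axes are not broadcast. When broadcasting across batch axes is needed, `jnp.matmul` is the more flexible choice.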