jax.experimental.pallas.mosaic_gpu.wgmma#
- jax.experimental.pallas.mosaic_gpu.wgmma(acc, a, b)[源代码]#
在给定的引用上执行异步 warp 组矩阵乘法累加。
从概念上讲,这等同于执行
acc[...] += a[...] @ b[...]
,只是计算是异步执行的。- 参数:
acc (gpu_core.WGMMAAbstractAccumulatorRef) – 累加器引用。需要通过调用
jax.experimental.pallas.run_scoped()
并传入一个jax.experimental.pallas.mosaic_gpu.WGMMAAccumulatorRef()
来分配。a – 左侧操作数引用。
b (pallas_core.TransformedRef) – 右侧操作数引用。
- 返回类型:
None