jax.experimental.pallas.mosaic_gpu.wgmma#
- jax.experimental.pallas.mosaic_gpu.wgmma(acc, a, b)[源代码]#
对给定的引用执行异步 warp group matmul-accumulate 操作。
从概念上讲,这等效于执行
acc[...] += a[...] @ b[...]
,只是计算是异步执行的。- 参数:
acc (gpu_core.WGMMAAbstractAccumulatorRef) – 累加器引用。需要通过调用
jax.experimental.pallas.run_scoped()
,并使用jax.experimental.pallas.mosaic_gpu.WGMMAAccumulatorRef()
进行分配。a – 左侧操作数引用。
b (pallas_core.TransformedRef) – 右侧操作数引用。
- 返回类型:
None