jax.experimental.pallas.mosaic_gpu.emit_pipeline#

jax.experimental.pallas.mosaic_gpu.emit_pipeline(body, *, grid, in_specs=(), out_specs=(), max_concurrent_steps=1, delay_release=0)[源代码]#

创建一个函数,用于在 Pallas 内核中发出手动流水线。

参数:
  • body (Callable[..., None]) – 流水线主体。

  • grid (pallas_core.StaticGrid) – 用于流水线的网格。

  • in_specs (Sequence[pallas_core.BlockSpec]) – 输入的块规格。

  • out_specs (Sequence[pallas_core.BlockSpec]) – 输出的块规格。

  • max_concurrent_steps (int) – 同时处于活动状态的连续阶段的最大数量。默认为 1。

  • delay_release (int) – 重用输入/输出引用之前等待的步数。默认为 0,并且必须严格小于 max_concurrent_steps。一般来说,如果不在主体中等待 WGMMA,则需要将其设置为 1。