jax.experimental.pallas.mosaic_gpu 模块#

Pallas 针对 H100 的实验性 GPU 后端。

这些 API 非常不稳定,可能会每周更改。使用风险自负。

#

Barrier(num_arrivals[, num_barriers])

GPUBlockSpec([block_shape, index_map, ...])

GPUCompilerParams(*[, approx_math, ...])

Mosaic GPU 编译器参数。

GPUMemorySpace(value)

一个枚举。

Layout(值)

一个枚举。

SwizzleTransform(交错)

TilingTransform(平铺)

表示内存引用的平铺转换。

TransposeTransform(置换)

转置平铺的内存引用。

WGMMAAccumulatorRef(形状, 数据类型, _初始化)

函数#

barrier_arrive(屏障)

到达给定的屏障。

barrier_wait(屏障)

等待给定的屏障。

commit_smem()

提交对 SMEM 的所有写入,使其对加载、TMA 和 WGMMA 可见。

copy_gmem_to_smem(源, 目标, 屏障)

异步将 GMEM 引用复制到 SMEM 引用。

copy_smem_to_gmem(源, 目标[, 谓词])

异步将 SMEM 引用复制到 GMEM 引用。

emit_pipeline(主体, *, 网格[, 输入规范, ...])

创建一个函数,在 Pallas 内核中发出手动流水线。

layout_cast(x, 新布局)

转换给定数组的布局。

set_max_registers(n, *, 操作)

设置 warp 拥有的最大寄存器数。

wait_smem_to_gmem(n[, wait_read_only])

等待直到没有超过 n 个 SMEM->GMEM 复制在进行中。

wgmma(累加器, a, b)

在给定引用上执行异步 warp 组矩阵乘法累加。

wgmma_wait(n)

等待直到没有超过 n 个 WGMMA 操作在进行中。

别名#

ACC

别名 of WGMMAAccumulatorRef

GMEM

别名 of jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM

SMEM

别名 of jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.SMEM