jax.experimental.pallas.mosaic_gpu.TilingTransform#
- class jax.experimental.pallas.mosaic_gpu.TilingTransform(tiling)[源代码]#
表示内存引用的平铺转换。
在形状为 (M, N) 的数组上对 (X, Y) 进行平铺将导致 (M // X, N // Y, X, Y) 的转换形状。例如。 使用 (64, 32) 的平铺平铺的 (256, 256) 块将被平铺为 (4, 8, 64, 32)。
方法
__init__
(tiling)batch
(leading_rank)返回一个转换,该转换接受一个具有额外 leading_rank 维度的引用。
to_gpu_transform
()undo
(ref)属性
tiling