jax.experimental.pallas.mosaic_gpu.TransposeTransform#

class jax.experimental.pallas.mosaic_gpu.TransposeTransform(permutation)[源代码]#

转置分块的内存引用。

参数:

permutation (tuple[int, ...])

__init__(permutation)#
参数:

permutation (tuple[int, ...])

返回类型:

None

方法

__init__(permutation)

batch(leading_rank)

返回一个变换,该变换接受具有额外 leading_rank 维度的引用。

to_gpu_transform()

undo(ref)

属性

permutation