jax.experimental.pallas.triton.TritonCompilerParams# class jax.experimental.pallas.triton.TritonCompilerParams(num_warps=None, num_stages=None, serialized_metadata=None)[源代码]# Triton 的编译器参数。 参数: num_warps (int | None) num_stages (int | None) serialized_metadata (bytes | None) num_warps# 内核使用的 warp 数量。每个 warp 由 32 个线程组成。 类型: int | None num_stages# 编译器应该用于软件流水循环的阶段数。 类型: int | None serialized_metadata# 附加的编译器元数据。此字段不稳定,未来可能会被删除。 类型: bytes | None __init__(num_warps=None, num_stages=None, serialized_metadata=None)# 参数: num_warps (int | None | None) num_stages (int | None | None) serialized_metadata (bytes | None | None) 返回类型: None 方法 __init__([num_warps, num_stages, ...]) 属性 PLATFORM num_stages num_warps serialized_metadata