GPU 性能提示#

本文档重点介绍神经网络工作负载的性能提示

矩阵乘法精度#

在最近的 GPU 世代(例如 Nvidia A100 世代或更高版本)上,以 bfloat16 精度执行大多数计算可能是一个好主意。例如,如果使用 Flax,请使用 flax.linen.Dense(..., dtype=jax.numpy.bfloat16) 实例化 Dense 层。以下是一些代码示例

XLA 性能标志#

注意

JAX-Toolbox 还有一个关于 NVIDIA XLA 性能标志 的页面。

XLA 标志的存在和确切行为可能取决于 jaxlib 版本。

截至 jaxlib==0.4.18 (发布于 2023 年 10 月 6 日),设置这些 XLA 标志可以提高性能。一些标志与 GPU 之间的通信有关,因此仅在多设备上运行计算时才相关,而另一些标志则与每个设备上的代码生成有关。

未来版本中,其中一些标志可能会默认设置。

这些标志可以通过 XLA_FLAGS shell 环境变量设置。例如,我们可以将其添加到 Python 文件的顶部:

import os
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
)

有关更多示例,另请参阅 为 Nvidia GPU 上的 Pax 训练推荐的 XLA 标志

代码生成标志#

  • –xla_gpu_triton_gemm_any 对它支持的任何 GEMM(矩阵乘法)使用基于 Triton 的 GEMM 发射器。默认值为 False。

通信标志#

  • –xla_gpu_enable_latency_hiding_scheduler 此标志启用延迟隐藏调度器,以有效地将异步通信与计算重叠。默认值为 False。

  • –xla_gpu_memory_limit_slop_factor 此标志充当应用于可用总内存的乘数,从而创建一个阈值,该阈值指导延迟隐藏调度器 (LHS) 在内存减少和延迟隐藏优化之间取得平衡。默认值为 95。

    此因子有效地为编译器传递建立了一个内存限制,确定调度器何时应优先考虑

    1. 内存减少:当内存使用量接近或超过计算的阈值时。

    2. 延迟隐藏:当内存使用量低于阈值时,允许更积极的优化,这些优化可能会暂时增加内存使用量,但可以提高整体性能。

    通过调整此因子,用户可以微调内存效率和性能优化之间的权衡。

  • –xla_gpu_enable_pipelined_collectives 当使用流水线并行时,此标志启用将第 (i+1) 层的权重 AllGather 与第 i 层的计算重叠。它还启用将第 (i+1) 层的权重 Reduce/ReduceScatter 与第 i 层的计算重叠。默认值为 False。启用此标志时存在一些错误。

  • –xla_gpu_collective_permute_decomposer_threshold 此标志在执行 GSPMD 流水线时很有用。设置非零阈值会将 CollectivePermute 分解为 CollectivePermuteReceiveDoneCollectivePermuteSendDone 对,以便可以在每个对应的 ReceiveDone/SendDone 对之间执行计算,从而实现更多重叠。默认情况下,阈值为 0,并且没有分解。将其设置为大于 0 的阈值,例如 --xla_gpu_collective_permute_decomposer_threshold=1024,可以启用此功能。

  • –xla_gpu_all_gather_combine_threshold_bytes –xla_gpu_reduce_scatter_combine_threshold_bytes –xla_gpu_all_reduce_combine_threshold_bytes 这些标志调整何时将多个小的 AllGather/ReduceScatter/AllReduce 合并为一个大的 AllGather/ReduceScatter/AllReduce,以减少跨设备通信的时间。例如,对于基于 Transformer 的工作负载上的 AllGather/ReduceScatter 阈值,请考虑将其调整得足够高,以便至少合并 Transformer 层的权重 AllGather/ReduceScatter。默认情况下,combine_threshold_bytes 设置为 256。

NCCL 标志#

这些 Nvidia NCCL 标志值可能对于在 Nvidia GPU 上进行单主机多设备计算很有用

os.environ.update({
  "NCCL_LL128_BUFFSIZE": "-2",
  "NCCL_LL_BUFFSIZE": "-2",
   "NCCL_PROTO": "SIMPLE,LL,LL128",
 })

这些 NCCL 标志可以提高单主机通信速度。这些标志似乎对多主机通信没有用处。

多进程#

我们建议每个 GPU 使用一个进程,而不是每个节点一个进程。在某些情况下,这可以加快即时编译的计算速度。jax.distributed.initialize() API 在 SLURM 下运行时会自动理解该配置。但是,这只是一条经验法则,测试每个 GPU 一个进程和每个节点一个进程对于您的用例可能都很有用。