GPU 性能提示#

本文档重点介绍神经网络工作负载的性能提示

Matmul 精度#

在最近的 GPU 世代中,例如 Nvidia A100 世代或更高版本,以 bfloat16 精度执行大多数计算可能是一个好主意。例如,如果使用 Flax,请使用 flax.linen.Dense(..., dtype=jax.numpy.bfloat16) 实例化 Dense 层。以下是一些代码示例

XLA 性能标志#

注意

JAX-Toolbox 也有一个关于 NVIDIA XLA 性能标志 的页面。

XLA 标志的存在和确切行为可能依赖于 jaxlib 版本。

截至 jaxlib==0.4.18(发布于 2023 年 10 月 6 日),设置这些 XLA 标志可以提高性能。一些与 GPU 之间的通信有关,因此仅在多设备上运行计算时才相关,而另一些则与每个设备上的代码生成有关。

其中一些可能会在未来的版本中默认设置。

这些标志可以通过 XLA_FLAGS shell 环境变量设置。例如,我们可以将其添加到 Python 文件的顶部

import os
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
)

有关更多示例,另请参阅 在 Nvidia GPU 上为 Pax 训练推荐的 XLA 标志

代码生成标志#

  • –xla_gpu_triton_gemm_any 对其支持的任何 GEMM(矩阵乘法)使用基于 Triton 的 GEMM 发射器。默认值为 False。

通信标志#

  • –xla_gpu_enable_latency_hiding_scheduler 此标志启用延迟隐藏调度器,以有效地将异步通信与计算重叠。默认值为 False。

  • –xla_gpu_memory_limit_slop_factor 此标志用作应用于可用总内存的乘数,创建一个阈值,指导延迟隐藏调度器 (LHS) 平衡内存减少和延迟隐藏优化。默认值为 95。

    此因子有效地为编译器传递建立了一个内存限制,确定调度程序何时应优先考虑

    1. 内存减少:当内存使用量接近或超过计算出的阈值时。

    2. 延迟隐藏:当内存使用量低于阈值时,允许更积极的优化,这可能会暂时增加内存使用量,但会提高整体性能。

    通过调整此因子,用户可以微调内存效率和性能优化之间的权衡。

  • –xla_gpu_enable_pipelined_collectives 使用流水线并行时,此标志启用将第 (i+1) 层权重 AllGather 与第 i 层计算重叠。它还启用将第 (i+1) 层权重 Reduce/ReduceScatter 与第 i 层的计算重叠。默认值为 False。启用此标志时存在一些错误。

  • –xla_gpu_collective_permute_decomposer_threshold 此标志在执行 GSPMD 流水线 时很有用。设置非零阈值会将 CollectivePermute 分解为 CollectivePermuteReceiveDoneCollectivePermuteSendDone 对,以便可以在每个对应的 ReceiveDone/SendDone 对之间执行计算,从而实现更多的重叠。默认情况下,阈值为 0,并且没有分解。将其设置为阈值 > 0,例如 --xla_gpu_collective_permute_decomposer_threshold=1024 可以启用此功能。

  • –xla_gpu_all_gather_combine_threshold_bytes –xla_gpu_reduce_scatter_combine_threshold_bytes –xla_gpu_all_reduce_combine_threshold_bytes 这些标志用于调整何时将多个小的 AllGather/ReduceScatter/AllReduce 组合成一个大的 AllGather/ReduceScatter/AllReduce,以减少跨设备通信的时间。例如,对于基于 Transformer 的工作负载上的 AllGather/ReduceScatter 阈值,考虑将它们调整得足够高,以便至少组合 Transformer 层的权重 AllGather/ReduceScatter。默认情况下,combine_threshold_bytes 设置为 256。

NCCL 标志#

这些 Nvidia NCCL 标志值可能对 Nvidia GPU 上单主机多设备计算很有用

os.environ.update({
  "NCCL_LL128_BUFFSIZE": "-2",
  "NCCL_LL_BUFFSIZE": "-2",
   "NCCL_PROTO": "SIMPLE,LL,LL128",
 })

这些 NCCL 标志可以提高单主机通信速度。这些标志似乎对多主机通信没有用。

多进程#

我们建议每个 GPU 使用一个进程,而不是每个节点一个进程。在某些情况下,这可以加速 jit 编译的计算。jax.distributed.initialize() API 在 SLURM 下运行时会自动理解该配置。但是,这只是一条经验法则,测试每个 GPU 一个进程和每个节点一个进程在您的用例中可能很有用。