GPU 性能提示#

本文档重点介绍神经网络工作负载的性能提示

矩阵乘法精度#

在最近几代 GPU(例如 Nvidia A100 及更高版本)上,最好以 bfloat16 精度执行大多数计算。例如,如果使用 Flax,请使用 flax.linen.Dense(..., dtype=jax.numpy.bfloat16) 实例化 Dense 层。以下是一些代码示例

XLA 性能标志#

注意

JAX-Toolbox 还有一页关于 NVIDIA XLA 性能标志 的内容。

XLA 标志的存在和确切行为可能取决于 jaxlib 版本。

截至 jaxlib==0.4.18(于 2023 年 10 月 6 日 发布),设置这些 XLA 标志可以提高性能。有些与 GPU 之间的通信有关,因此仅在多个设备上运行计算时才相关,而另一些则与每个设备上的代码生成有关。

其中一些可能在将来的版本中默认设置。

这些标志可以通过 XLA_FLAGS shell 环境变量设置。例如,我们可以在 Python 文件的开头添加它

import os
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_async_collectives=true '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_highest_priority_async_stream=true '
)

有关更多示例,另请参阅 在 Nvidia GPU 上进行 Pax 训练的推荐 XLA 标志

代码生成标志#

  • –xla_gpu_enable_triton_softmax_fusion 此标志启用基于模式匹配并由 Triton 代码生成的自动 softmax 融合。默认值为 False。

  • –xla_gpu_triton_gemm_any 对其支持的任何 GEMM(矩阵乘法)发射器使用基于 Triton 的 GEMM。默认值为 False。

通信标志#

  • –xla_gpu_enable_latency_hiding_scheduler 此标志启用延迟隐藏调度程序,以有效地将异步通信与计算重叠。默认值为 False。

  • –xla_gpu_enable_pipelined_collectives 使用流水线并行时,此标志启用将第 (i+1) 层权重 AllGather 与第 i 层计算重叠。它还启用将第 (i+1) 层权重 Reduce/ReduceScatter 与第 i 层的计算重叠。默认值为 False。**启用此标志时存在一些错误。**

  • –xla_gpu_collective_permute_decomposer_threshold 当执行 GSPMD 流水线 时,此标志很有用。设置非零阈值会将 CollectivePermute 分解成 CollectivePermuteReceiveDoneCollectivePermuteSendDone 对,以便可以在每个对应的 ReceiveDone/SendDone 对之间执行计算,从而实现更多重叠。默认情况下,阈值为 0,并且没有分解。将其设置为大于 0 的阈值,例如 --xla_gpu_collective_permute_decomposer_threshold=1024 可以启用此功能。

  • –xla_gpu_all_gather_combine_threshold_bytes –xla_gpu_reduce_scatter_combine_threshold_bytes –xla_gpu_all_reduce_combine_threshold_bytes 这些标志调整何时将多个小的 AllGather/ReduceScatter/AllReduce 组合成一个大的 AllGather/ReduceScatter/AllReduce 以减少跨设备通信所花费的时间。例如,对于基于 Transformer 的工作负载上的 AllGather/ReduceScatter 阈值,请考虑将它们调整到足够高,以便至少组合 Transformer 层的权重 AllGather/ReduceScatter。默认情况下,combine_threshold_bytes 设置为 256。

NCCL 标志#

这些 Nvidia NCCL 标志值可能对 Nvidia GPU 上的单主机多设备计算很有用

os.environ.update({
  "NCCL_LL128_BUFFSIZE": "-2",
  "NCCL_LL_BUFFSIZE": "-2",
   "NCCL_PROTO": "SIMPLE,LL,LL128",
 })

这些 NCCL 标志可以提高单主机通信速度。这些标志似乎尚未用于多主机通信。

多进程#

我们建议每个 GPU 使用一个进程,而不是每个节点使用一个进程。在某些情况下,这可以加快 jitted 计算速度。当在 SLURM 下运行时,jax.distributed.initialize() API 会自动理解该配置。但是,这只是一个经验法则,在您的用例中测试每个 GPU 一个进程和每个节点一个进程可能会有用。