GPU 性能提示#
本文档重点介绍神经网络工作负载的性能提示
Matmul 精度#
在最近的 GPU 世代中,例如 Nvidia A100 世代或更高版本,以 bfloat16
精度执行大多数计算可能是一个好主意。例如,如果使用 Flax,请使用 flax.linen.Dense(..., dtype=jax.numpy.bfloat16)
实例化 Dense
层。以下是一些代码示例
在 Flax LM1B 示例中,
Dense
模块使用 可配置的 dtype 实例化,该 dtype 默认为 bfloat16。在 MaxText 中,
DenseGeneral
模块也 使用可配置的 dtype 实例化,该 dtype 默认为 bfloat16。
XLA 性能标志#
注意
JAX-Toolbox 也有一个关于 NVIDIA XLA 性能标志 的页面。
XLA 标志的存在和确切行为可能依赖于 jaxlib
版本。
截至 jaxlib==0.4.18
(发布于 2023 年 10 月 6 日),设置这些 XLA 标志可以提高性能。一些与 GPU 之间的通信有关,因此仅在多设备上运行计算时才相关,而另一些则与每个设备上的代码生成有关。
其中一些可能会在未来的版本中默认设置。
这些标志可以通过 XLA_FLAGS
shell 环境变量设置。例如,我们可以将其添加到 Python 文件的顶部
import os
os.environ['XLA_FLAGS'] = (
'--xla_gpu_triton_gemm_any=True '
'--xla_gpu_enable_latency_hiding_scheduler=true '
)
有关更多示例,另请参阅 在 Nvidia GPU 上为 Pax 训练推荐的 XLA 标志。
代码生成标志#
–xla_gpu_triton_gemm_any 对其支持的任何 GEMM(矩阵乘法)使用基于 Triton 的 GEMM 发射器。默认值为 False。
通信标志#
–xla_gpu_enable_latency_hiding_scheduler 此标志启用延迟隐藏调度器,以有效地将异步通信与计算重叠。默认值为 False。
–xla_gpu_memory_limit_slop_factor 此标志用作应用于可用总内存的乘数,创建一个阈值,指导延迟隐藏调度器 (LHS) 平衡内存减少和延迟隐藏优化。默认值为 95。
此因子有效地为编译器传递建立了一个内存限制,确定调度程序何时应优先考虑
内存减少:当内存使用量接近或超过计算出的阈值时。
延迟隐藏:当内存使用量低于阈值时,允许更积极的优化,这可能会暂时增加内存使用量,但会提高整体性能。
通过调整此因子,用户可以微调内存效率和性能优化之间的权衡。
–xla_gpu_enable_pipelined_collectives 使用流水线并行时,此标志启用将第 (i+1) 层权重
AllGather
与第 i 层计算重叠。它还启用将第 (i+1) 层权重Reduce
/ReduceScatter
与第 i 层的计算重叠。默认值为 False。启用此标志时存在一些错误。–xla_gpu_collective_permute_decomposer_threshold 此标志在执行 GSPMD 流水线 时很有用。设置非零阈值会将
CollectivePermute
分解为CollectivePermuteReceiveDone
和CollectivePermuteSendDone
对,以便可以在每个对应的ReceiveDone
/SendDone
对之间执行计算,从而实现更多的重叠。默认情况下,阈值为 0,并且没有分解。将其设置为阈值 > 0,例如--xla_gpu_collective_permute_decomposer_threshold=1024
可以启用此功能。–xla_gpu_all_gather_combine_threshold_bytes –xla_gpu_reduce_scatter_combine_threshold_bytes –xla_gpu_all_reduce_combine_threshold_bytes 这些标志用于调整何时将多个小的
AllGather
/ReduceScatter
/AllReduce
组合成一个大的AllGather
/ReduceScatter
/AllReduce
,以减少跨设备通信的时间。例如,对于基于 Transformer 的工作负载上的AllGather
/ReduceScatter
阈值,考虑将它们调整得足够高,以便至少组合 Transformer 层的权重AllGather
/ReduceScatter
。默认情况下,combine_threshold_bytes
设置为 256。
NCCL 标志#
这些 Nvidia NCCL 标志值可能对 Nvidia GPU 上单主机多设备计算很有用
os.environ.update({
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
})
这些 NCCL 标志可以提高单主机通信速度。这些标志似乎对多主机通信没有用。
多进程#
我们建议每个 GPU 使用一个进程,而不是每个节点一个进程。在某些情况下,这可以加速 jit 编译的计算。jax.distributed.initialize()
API 在 SLURM 下运行时会自动理解该配置。但是,这只是一条经验法则,测试每个 GPU 一个进程和每个节点一个进程在您的用例中可能很有用。