GPU 性能提示#
本文档重点介绍神经网络工作负载的性能提示
矩阵乘法精度#
在最近几代 GPU(例如 Nvidia A100 及更高版本)上,最好以 bfloat16
精度执行大多数计算。例如,如果使用 Flax,请使用 flax.linen.Dense(..., dtype=jax.numpy.bfloat16)
实例化 Dense
层。以下是一些代码示例
在 Flax LM1B 示例 中,
Dense
模块 使用可配置的 dtype 实例化,该 dtype 默认为 bfloat16。在 MaxText 中,
DenseGeneral
模块也 使用可配置的 dtype 实例化,该 dtype 默认为 bfloat16。
XLA 性能标志#
注意
JAX-Toolbox 还有一页关于 NVIDIA XLA 性能标志 的内容。
XLA 标志的存在和确切行为可能取决于 jaxlib
版本。
截至 jaxlib==0.4.18
(于 2023 年 10 月 6 日 发布),设置这些 XLA 标志可以提高性能。有些与 GPU 之间的通信有关,因此仅在多个设备上运行计算时才相关,而另一些则与每个设备上的代码生成有关。
其中一些可能在将来的版本中默认设置。
这些标志可以通过 XLA_FLAGS
shell 环境变量设置。例如,我们可以在 Python 文件的开头添加它
import os
os.environ['XLA_FLAGS'] = (
'--xla_gpu_enable_triton_softmax_fusion=true '
'--xla_gpu_triton_gemm_any=True '
'--xla_gpu_enable_async_collectives=true '
'--xla_gpu_enable_latency_hiding_scheduler=true '
'--xla_gpu_enable_highest_priority_async_stream=true '
)
有关更多示例,另请参阅 在 Nvidia GPU 上进行 Pax 训练的推荐 XLA 标志。
代码生成标志#
–xla_gpu_enable_triton_softmax_fusion 此标志启用基于模式匹配并由 Triton 代码生成的自动 softmax 融合。默认值为 False。
–xla_gpu_triton_gemm_any 对其支持的任何 GEMM(矩阵乘法)发射器使用基于 Triton 的 GEMM。默认值为 False。
通信标志#
–xla_gpu_enable_latency_hiding_scheduler 此标志启用延迟隐藏调度程序,以有效地将异步通信与计算重叠。默认值为 False。
–xla_gpu_enable_pipelined_collectives 使用流水线并行时,此标志启用将第 (i+1) 层权重
AllGather
与第 i 层计算重叠。它还启用将第 (i+1) 层权重Reduce
/ReduceScatter
与第 i 层的计算重叠。默认值为 False。**启用此标志时存在一些错误。**–xla_gpu_collective_permute_decomposer_threshold 当执行 GSPMD 流水线 时,此标志很有用。设置非零阈值会将
CollectivePermute
分解成CollectivePermuteReceiveDone
和CollectivePermuteSendDone
对,以便可以在每个对应的ReceiveDone
/SendDone
对之间执行计算,从而实现更多重叠。默认情况下,阈值为 0,并且没有分解。将其设置为大于 0 的阈值,例如--xla_gpu_collective_permute_decomposer_threshold=1024
可以启用此功能。–xla_gpu_all_gather_combine_threshold_bytes –xla_gpu_reduce_scatter_combine_threshold_bytes –xla_gpu_all_reduce_combine_threshold_bytes 这些标志调整何时将多个小的
AllGather
/ReduceScatter
/AllReduce
组合成一个大的AllGather
/ReduceScatter
/AllReduce
以减少跨设备通信所花费的时间。例如,对于基于 Transformer 的工作负载上的AllGather
/ReduceScatter
阈值,请考虑将它们调整到足够高,以便至少组合 Transformer 层的权重AllGather
/ReduceScatter
。默认情况下,combine_threshold_bytes
设置为 256。
NCCL 标志#
这些 Nvidia NCCL 标志值可能对 Nvidia GPU 上的单主机多设备计算很有用
os.environ.update({
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
})
这些 NCCL 标志可以提高单主机通信速度。这些标志似乎尚未用于多主机通信。
多进程#
我们建议每个 GPU 使用一个进程,而不是每个节点使用一个进程。在某些情况下,这可以加快 jitted 计算速度。当在 SLURM 下运行时,jax.distributed.initialize()
API 会自动理解该配置。但是,这只是一个经验法则,在您的用例中测试每个 GPU 一个进程和每个节点一个进程可能会有用。