XLA 编译器标志列表#
简介#
本指南简要概述了 XLA 以及 XLA 如何与 Jax 相关。有关深入的详细信息,请参阅 XLA 文档。然后,它列出了常用的 XLA 编译器标志,旨在优化 Jax 程序的性能。
XLA: Jax 背后的强大引擎#
XLA (加速线性代数) 是一种用于线性代数的特定领域编译器,在 Jax 的性能和灵活性方面起着关键作用。它使 Jax 能够通过将您的 Python/NumPy 式代码转换并编译成高效的机器指令,为各种硬件后端(CPU、GPU、TPU)生成优化代码。
Jax 使用 XLA 的 JIT 编译功能,在运行时将您的 Python 函数转换为优化的 XLA 计算。
在 Jax 中配置 XLA:#
您可以通过在运行 Python 脚本或 colab notebook 之前设置 XLA_FLAGS 环境变量来影响 Jax 中 XLA 的行为。
对于 colab notebook
使用 os.environ['XLA_FLAGS']
提供标志
import os
# Set multiple flags separated by spaces
os.environ['XLA_FLAGS'] = '--flag1=value1 --flag2=value2'
对于 Python 脚本
将 XLA_FLAGS
指定为命令行的一部分
XLA_FLAGS='--flag1=value1 --flag2=value2' python3 source.py
重要提示
在导入 Jax 或其他相关库之前设置
XLA_FLAGS
。在后端初始化后更改XLA_FLAGS
将不会生效,并且鉴于后端初始化时间未明确定义,通常在执行任何 Jax 代码之前设置XLA_FLAGS
更安全。尝试使用不同的标志来优化您特定用例的性能。
更多信息
有关 XLA 的完整和最新的文档,请参阅官方 XLA 文档。
对于 OpenXLA 开源版本支持的后端(CPU、GPU),XLA 标志及其默认值在 xla/debug_options_flags.cc 中定义,完整的标志列表可以在这里找到。
TPU 编译器标志不是 OpenXLA 的一部分,但下面列出了常用的选项。
请注意,此标志列表并不详尽,并且可能会发生变化。这些标志是实现细节,不保证它们将保持可用或保持其当前行为。
常用 XLA 标志#
标志 |
类型 |
备注 |
---|---|---|
|
字符串 (文件路径) |
预优化 HLO 文件和其他工件将放置的文件夹(请参阅 XLA 工具)。 |
|
三态标志 (true/false/auto) |
将所有 collective-permute 操作重写为其异步变体。当设置为 |
|
三态标志 (true/false/auto) |
如果设置为 true,则启用异步 all gather。如果为 |
|
字符串(逗号分隔的传递名称列表) |
要禁用的 HLO 传递的逗号分隔列表。这些名称必须与传递名称完全匹配(逗号周围没有空格)。 |
TPU XLA 标志#
标志 |
类型 |
备注 |
---|---|---|
|
布尔值(true/false) |
用于增加 DCN(数据中心网络) all-reduces 的重叠机会的优化,用于数据并行分片。 |
|
布尔值(true/false) |
即使数据并行操作的输出大小与堆叠变量中可以就地保存的大小不匹配,也可以跨多个迭代流水线化数据并行操作。可能会增加内存压力。 |
|
布尔值(true/false) |
启用传递,将异步集合通信与在它们的 -start 和 -done 指令之间调度的计算操作(输出/循环融合或卷积)融合在一起。 |
|
三态标志 (true/false/auto) |
启用在 AsyncCollectiveFusion 传递中融合 all-gathers。 |
|
布尔值(true/false) |
启用在 AsyncCollectiveFusion 传递中的多个步骤(融合)中继续相同的异步集合。 |
|
布尔值(true/false) |
启用在单个 TensorCore 上计算和通信的重叠,即 MegaCore 融合的一个核心等效项。 |
|
布尔值(true/false) |
是否以分区方式运行 RngBitGenerator HLO,如果计算的不同部分的 shards 之间需要确定性结果,则这种方式是不安全的。 |
|
布尔值(true/false) |
允许将 all-gathers 与卷积/all-reduces 融合。 |
|
布尔值(true/false) |
通过扫描循环向后流水线化 all-gathers(目前是 megascale all-gathers)。 |
GPU XLA 标志#
标志 |
类型 |
备注 |
---|---|---|
|
布尔值(true/false) |
此标志启用延迟隐藏调度程序,以有效地将异步通信与计算重叠。默认值为 False。 |
|
布尔值(true/false) |
使用基于 Triton 的矩阵乘法。 |
|
标志 (0-3) |
用于设置 GPU 图级别的旧版标志。在新用例中使用 xla_gpu_enable_command_buffer。0 = 关闭;1 = 捕获融合和 memcpys;2 = 捕获 gemms;3 = 捕获卷积。 |
|
整数(字节) |
这些标志用于调整何时将多个小的 AllGather / ReduceScatter / AllReduce 合并为一个大的 AllGather / ReduceScatter / AllReduce,以减少在跨设备通信上花费的时间。例如,对于基于 Transformer 的工作负载上的 AllGather / ReduceScatter 阈值,请考虑将它们调整得足够高,以便至少合并一个 Transformer 层的权重 AllGather / ReduceScatter。默认情况下,combine_threshold_bytes 设置为 256。 |
|
整数(字节) |
请参阅上面的 xla_gpu_all_reduce_combine_threshold_bytes。 |
|
整数(字节) |
请参阅上面的 xla_gpu_all_reduce_combine_threshold_bytes。 |
|
布尔值(true/false) |
启用 all-gather 指令的流水线化。 |
|
布尔值(true/false) |
启用 reduce-scatter 指令的流水线化。 |
|
布尔值(true/false) |
启用 all-reduce 指令的流水线化。 |
|
布尔值(true/false) |
启用 while 循环的双缓冲。 |
|
布尔值(true/false) |
合并具有相同收集维度的 all-gather 操作,或者合并与维度无关的操作。 |
|
布尔值(true/false) |
合并具有相同维度的 reduce-scatter 操作,或者合并与维度无关的操作。 |
延伸阅读