XLA 编译器标志列表#

简介#

本指南简要概述了 XLA 以及 XLA 与 Jax 的关系。有关深入的详细信息,请参阅 XLA 文档。然后,它列出了常用的 XLA 编译器标志,旨在优化 Jax 程序的性能。

XLA:Jax 背后的动力#

XLA(加速线性代数)是一个用于线性代数的特定领域编译器,在 Jax 的性能和灵活性方面发挥着关键作用。它通过将您的 Python/NumPy 类代码转换为高效的机器指令,使 Jax 能够为各种硬件后端(CPU、GPU、TPU)生成优化的代码。

Jax 使用 XLA 的 JIT 编译功能在运行时将您的 Python 函数转换为优化的 XLA 计算。

在 Jax 中配置 XLA:#

您可以在运行 Python 脚本或 colab notebook 之前,通过设置 XLA_FLAGS 环境变量来影响 XLA 在 Jax 中的行为。

对于 colab notebook

使用 os.environ['XLA_FLAGS'] 提供标志

import os

# Set multiple flags separated by spaces
os.environ['XLA_FLAGS'] = '--flag1=value1 --flag2=value2'

对于 python 脚本

在 cli 命令中指定 XLA_FLAGS

XLA_FLAGS='--flag1=value1 --flag2=value2'  python3 source.py

重要提示

  • 在导入 Jax 或其他相关库之前设置 XLA_FLAGS。在后端初始化后更改 XLA_FLAGS 将不起作用,并且鉴于后端初始化时间未明确定义,通常在执行任何 Jax 代码之前设置 XLA_FLAGS 更安全。

  • 尝试不同的标志以优化特定用例的性能。

更多信息

  • 有关 XLA 的完整和最新的文档,请参阅官方 XLA 文档

  • 对于 XLA 的开源版本(CPU、GPU)支持的后端,XLA 标志及其默认值在 xla/debug_options_flags.cc 中定义,并且可以在 这里 找到完整的标志列表。

  • TPU 编译器标志不是 OpenXLA 的一部分,但下面列出了常用的选项。

  • 请注意,此标志列表并不详尽,可能会发生变化。这些标志是实现细节,不能保证它们将保持可用或维持其当前行为。

常用的 XLA 标志#

标志

类型

说明

xla_dump_to

字符串(文件路径)

放置预优化 HLO 文件和其他工件的文件夹(请参阅 XLA 工具)。

xla_enable_async_collective_permute

三态标志(true/false/auto)

将所有集体置换操作重写为其异步变体。当设置为 auto 时,XLA 可以根据其他配置或条件自动启用异步集体操作。

xla_enable_async_all_gather

三态标志(true/false/auto)

如果设置为 true,则启用异步 all gather。如果为 auto,则仅为实现异步 all-gather 的平台启用。该实现(例如 BC 卸载或连续融合)根据其他标志值选择。

xla_disable_hlo_passes

字符串(以逗号分隔的 pass 名称列表)

要禁用的 HLO pass 的逗号分隔列表。这些名称必须与 pass 名称完全匹配(逗号周围没有空格)。

TPU XLA 标志#

标志

类型

说明

xla_tpu_enable_data_parallel_all_reduce_opt

布尔值(true/false)

优化,以增加用于数据并行分片的 DCN(数据中心网络)all-reduces 的重叠机会。

xla_tpu_data_parallel_opt_different_sized_ops

布尔值(true/false)

即使数据并行操作的输出大小与堆叠变量中可就地保存的大小不匹配,也可以在多个迭代中启用数据并行操作的流水线处理。可能会增加内存压力。

xla_tpu_enable_async_collective_fusion

布尔值(true/false)

启用 pass,将异步集体通信与在其 -start 和 -done 指令之间调度的计算操作(输出/循环融合或卷积)融合。

xla_tpu_enable_async_collective_fusion_fuse_all_gather

三态标志(true/false/auto)

启用在 AsyncCollectiveFusion pass 内融合 all-gathers。
如果设置为 auto,则将根据目标启用。

xla_tpu_enable_async_collective_fusion_multiple_steps

布尔值(true/false)

启用在 AsyncCollectiveFusion pass 中在多个步骤(融合)中继续相同的异步集体操作。

xla_tpu_overlap_compute_collective_tc

布尔值(true/false)

在单个 TensorCore 上启用计算和通信的重叠,即,MegaCore 融合的一个核心等效项。

xla_tpu_spmd_rng_bit_generator_unsafe

布尔值(true/false)

是否以分区方式运行 RngBitGenerator HLO,如果对计算的不同部分上的不同分片期望确定性结果,则该方式是不安全的。

xla_tpu_megacore_fusion_allow_ags

布尔值(true/false)

允许将 all-gathers 与卷积/all-reduces 融合。

xla_tpu_enable_ag_backward_pipelining

布尔值(true/false)

通过扫描循环向后流水线化 all-gathers(目前为 megascale all-gathers)。

GPU XLA 标志#

标志

类型

说明

xla_gpu_enable_latency_hiding_scheduler

布尔值(true/false)

此标志启用延迟隐藏调度器,以高效地将异步通信与计算重叠。默认值为 False。

xla_gpu_enable_triton_gemm

布尔值(true/false)

使用基于 Triton 的矩阵乘法。

xla_gpu_graph_level

标志(0-3)

用于设置 GPU 图级别的旧标志。在新用例中使用 xla_gpu_enable_command_buffer。0 = 关闭;1 = 捕获融合和 memcpys;2 = 捕获 gemms;3 = 捕获卷积。

xla_gpu_all_reduce_combine_threshold_bytes

整数(字节)

这些标志调整何时将多个小的 AllGather/ReduceScatter/AllReduce 合并为一个大的 AllGather/ReduceScatter/AllReduce,以减少在跨设备通信上花费的时间。例如,对于基于 Transformer 的工作负载上的 AllGather/ReduceScatter 阈值,请考虑将其调整得足够高,以便至少合并一个 Transformer 层的权重 AllGather/ReduceScatter。默认情况下,combine_threshold_bytes 设置为 256。

xla_gpu_all_gather_combine_threshold_bytes

整数(字节)

请参阅上面的 xla_gpu_all_reduce_combine_threshold_bytes。

xla_gpu_reduce_scatter_combine_threshold_bytes

整数(字节)

请参阅上面的 xla_gpu_all_reduce_combine_threshold_bytes。

xla_gpu_enable_pipelined_all_gather

布尔值(true/false)

启用 all-gather 指令的流水线处理。

xla_gpu_enable_pipelined_reduce_scatter

布尔值(true/false)

启用 reduce-scatter 指令的流水线处理。

xla_gpu_enable_pipelined_all_reduce

布尔值(true/false)

启用 all-reduce 指令的流水线处理。

xla_gpu_enable_while_loop_double_buffering

布尔值(true/false)

为 while 循环启用双缓冲。

xla_gpu_enable_triton_softmax_fusion

布尔值(true/false)

使用基于 Triton 的 Softmax 融合。

xla_gpu_enable_all_gather_combine_by_dim

布尔值(true/false)

合并具有相同收集维度或与其维度无关的 all-gather 操作。

xla_gpu_enable_reduce_scatter_combine_by_dim

布尔值(true/false)

合并具有相同维度或与其维度无关的 reduce-scatter 操作。

其他阅读材料