调查回归#

所以您更新了 JAX,然后遇到了速度回归?您有一点时间并准备调查这个问题?我们首先要创建一个 JAX 问题。但是,如果您能查明触发回归的提交,那就对我们有很大的帮助。

本文档解释了我们如何识别导致 15% 性能回归 的提交。

步骤#

如果复制器足够快,这可以轻松完成。这是一种蛮力方法,而不是二分法,但如果复制器足够快,它效果很好。这确保您始终测试兼容的 XLA 和 JAX 提交。它还限制了 XLA 重新编译。

以下是一个建议的调查策略

  1. 您可以对两个版本之间的夜间容器进行蛮力测试。

  2. 每小时重新编译,同时保持 XLA 和 JAX 同步。

  3. 最终验证:也许是对几个提交的手动检查(或 git 二分法)。

夜间调查#

这可以通过使用 NVIDIA JAX-Toolbox 夜间容器 完成。

  • 在某些日子里,错误会阻止容器构建,或者会出现暂时的回归。只需丢弃那些日子即可。

  • 因此,您应该会得到回归发生的特定日期或几天。

  • 要实现自动化,您需要两个 Python 脚本

    • test_runner.sh:将启动容器和测试。

    • test.sh:将安装缺少的依赖项并运行测试

以下是在该问题中使用的真实示例脚本:https://github.com/google/jax/issues/17686

  • test_runner.sh

  for m in 7 8 9; do
    for d in `seq -w 1 30`; do
      docker run -v $PWD:/dir --gpus=all ghcr.io/nvidia/jax:nightly-2023-0${m}-${d} /bin/bash /dir/test.sh &> OUT-0${m}-${d}
    done
  Done
  • test.sh

  pip install jmp pyvista numpy matplotlib Rtree trimesh jmp termcolor orbax
  git clone https://github.com/Autodesk/XLB
  cd XLB
  export PYTHONPATH=.
  export CUDA_VISIBLE_DEVICES=0 # only 1 GPU is needed

  python3 examples/performance/MLUPS3d.py 256 200

然后,您可以 grep 每个输出以查看回归何时发生:grep MLUPS OUT*。以下是我们得到的结果

OUT-07-06:MLUPS: 587.9240990200157
OUT-07-07:MLUPS: 587.8907972116419
OUT-07-08:MLUPS: 587.3186499464459
OUT-07-09:MLUPS: 587.3130127722537
OUT-07-10:MLUPS: 587.8526619429658
OUT-07-17:MLUPS: 570.1631097290182
OUT-07-18:MLUPS: 570.2819775617064
OUT-07-19:MLUPS: 570.1672213357352
OUT-07-20:MLUPS: 587.437153685251
OUT-07-21:MLUPS: 587.6702557143142
OUT-07-25:MLUPS: 577.3063618431178
OUT-07-26:MLUPS: 577.2362978080912
OUT-07-27:MLUPS: 577.2101850145785
OUT-07-28:MLUPS: 577.0716349809895
OUT-07-29:MLUPS: 577.4223280707176
OUT-07-30:MLUPS: 577.2255967221336
OUT-08-01:MLUPS: 577.277685388252
OUT-08-02:MLUPS: 577.0137874289354
OUT-08-03:MLUPS: 577.1333281553946
OUT-08-04:MLUPS: 577.305012020407
OUT-08-05:MLUPS: 577.2143988866626
OUT-08-06:MLUPS: 577.2409145495443
OUT-08-07:MLUPS: 577.2602819927345
OUT-08-08:MLUPS: 577.2823738293221
OUT-08-09:MLUPS: 577.3453199728248
OUT-08-11:MLUPS: 577.3161423260563
OUT-08-12:MLUPS: 577.1697775786824
OUT-08-13:MLUPS: 577.3049883393633
OUT-08-14:MLUPS: 576.9051978525331
OUT-08-15:MLUPS: 577.5331743016213
OUT-08-16:MLUPS: 577.5117505070573
OUT-08-18:MLUPS: 577.5930698237612
OUT-08-19:MLUPS: 577.3539885757353
OUT-08-20:MLUPS: 577.4190113959127
OUT-08-21:MLUPS: 577.300394253605
OUT-08-22:MLUPS: 577.4263792037783
OUT-08-23:MLUPS: 577.4087536357031
OUT-08-24:MLUPS: 577.1094728438082
OUT-08-25:  File "/XLB/examples/performance/MLUPS3d.py", line 5, in <module>
OUT-08-26:MLUPS: 537.0164618489928
OUT-08-27:MLUPS: 536.9545448661609
OUT-08-28:MLUPS: 536.2887650464874
OUT-08-29:MLUPS: 536.7178471720636
OUT-08-30:MLUPS: 536.6978912984252
OUT-09-01:MLUPS: 536.7030899164106
OUT-09-04:MLUPS: 536.5339818238837
OUT-09-05:MLUPS: 536.6507808565617
OUT-09-06:MLUPS: 536.7144494518315
OUT-09-08:MLUPS: 536.7376612408998
OUT-09-09:MLUPS: 536.7798324141778
OUT-09-10:MLUPS: 536.726157440174
OUT-09-11:MLUPS: 536.7446210750584
OUT-09-12:MLUPS: 536.6707332269023
OUT-09-13:MLUPS: 536.6777936517823
OUT-09-14:MLUPS: 536.7581523280307
OUT-09-15:MLUPS: 536.6156273667873
OUT-09-16:MLUPS: 536.7320935035265
OUT-09-17:MLUPS: 536.7104991444398
OUT-09-18:MLUPS: 536.7492269469092
OUT-09-19:MLUPS: 536.6760131792959
OUT-09-20:MLUPS: 536.7361260076634

这发现 8-24 很好,但 8-26 很糟糕。在 8-25,另一个问题阻止了获取结果。因此,我们需要在 8-24 和 8-26 之间每小时进行调查。早些时候出现了一个较小的减速,在本示例中我们忽略它。这只是在这些日期之间进行另一次每小时调查。

每小时调查#

该脚本每小时在两个日期之间执行一次 JAX 和 XLA 的检出,重新构建所有内容并运行测试。脚本的结构有所不同。我们启动工作容器并保留它。然后在容器内,我们只触发增量 XLA 构建,除了第一次构建之外。因此,在第一次迭代之后,速度会快得多。

  • test_runner2.sh

  # Execute this script inside the container:
  # docker run -v $PWD:/dir --gpus=all ghcr.io/nvidia/jax:nightly-2023-08-24 /bin/bash
  cd /opt/xla-source
  git remote update
  cd /opt/jax-source
  git remote update
  pip install jmp pyvista numpy matplotlib Rtree trimesh jmp termcolor orbax
  cd /tmp
  git clone https://github.com/Autodesk/XLB
  cd XLB

  for d in `seq -w 24 26`; do
      for h in `seq -w 0 24`; do
          echo $m $d $h
          /bin/bash /dir/test2.sh Aug $d 2023 $h:00:00 &> OUT-08-${d}-$h
      done
  done
  • test2.sh

  echo "param: $@"
  cd /opt/xla-source
  git checkout `git rev-list -1 --before="$*" origin/main`
  git show -q
  cd /opt/jax-source
  git checkout `git rev-list -1 --before="$*" origin/main`
  git show -q

  rm /opt/jax-source/dist/jax*.whl
  build-jax.sh # The script is in the nightly container

  export PYTHONPATH=.
  export CUDA_VISIBLE_DEVICES=0 # only 1 GPU is needed

  python3 examples/performance/MLUPS3d.py 256 200

现在,您可以在新的输出文件中执行 grep 命令以查看问题出现在哪个时间段之间。

最终验证#

为此,您需要检查这些时间段之间的 JAX 和 XLA 历史记录。可能有一些提交需要测试。如果您想更高级,可以使用 git bisect。

可以改进吗?#

是的!如果这是一个崩溃回归,能够进行二分搜索将非常有用。但这将更加复杂。如果有人想贡献这样的说明,请提交 PR ;)

对于速度回归,二分搜索可能会隐藏一些信息。我们不会很容易地看到这里出现了两个回归。