调查回归#
所以你更新了 JAX 并遇到了速度回归?你有一些时间并准备好调查这个问题了吗?我们首先创建一个 JAX issue。但是如果你能准确指出触发回归的 commit,那将对我们非常有帮助。
本文档解释了我们如何识别导致 15% 性能回归的 commit。
步骤#
如果复现程序足够快,就可以轻松完成。这是一种暴力方法,而不是二分法,但如果复现程序足够快,它效果很好。这确保了你始终测试兼容的 XLA 和 JAX commit。它也限制了 XLA 的重新编译。
这是一个建议的调查策略
你可以对两个版本之间的每日构建容器进行暴力测试。
保持 XLA 和 JAX 同步的情况下,每小时重新编译。
最终验证:可能需要手动检查一些 commit(或使用 git bisect)。
每晚调查#
这可以通过使用 NVIDIA JAX-Toolbox 每晚构建容器来完成。
有时,错误会阻止容器的构建,或者会出现临时的回归。只需忽略那些天即可。
因此,你最终应该得到发生回归的特定一天或几天。
要自动化此过程,你需要 2 个 python 脚本
test_runner.sh:将启动容器和测试。
test.sh:将安装缺失的依赖项并运行测试。
以下是用于该 issue 的真实示例脚本:https://github.com/jax-ml/jax/issues/17686
test_runner.sh
for m in 7 8 9; do
for d in `seq -w 1 30`; do
docker run -v $PWD:/dir --gpus=all ghcr.io/nvidia/jax:nightly-2023-0${m}-${d} /bin/bash /dir/test.sh &> OUT-0${m}-${d}
done
Done
test.sh
pip install jmp pyvista numpy matplotlib Rtree trimesh jmp termcolor orbax
git clone https://github.com/Autodesk/XLB
cd XLB
export PYTHONPATH=.
export CUDA_VISIBLE_DEVICES=0 # only 1 GPU is needed
python3 examples/performance/MLUPS3d.py 256 200
然后你可以 grep 每个输出,以查看回归发生的时间: grep MLUPS OUT*
。以下是我们得到的结果
OUT-07-06:MLUPS: 587.9240990200157
OUT-07-07:MLUPS: 587.8907972116419
OUT-07-08:MLUPS: 587.3186499464459
OUT-07-09:MLUPS: 587.3130127722537
OUT-07-10:MLUPS: 587.8526619429658
OUT-07-17:MLUPS: 570.1631097290182
OUT-07-18:MLUPS: 570.2819775617064
OUT-07-19:MLUPS: 570.1672213357352
OUT-07-20:MLUPS: 587.437153685251
OUT-07-21:MLUPS: 587.6702557143142
OUT-07-25:MLUPS: 577.3063618431178
OUT-07-26:MLUPS: 577.2362978080912
OUT-07-27:MLUPS: 577.2101850145785
OUT-07-28:MLUPS: 577.0716349809895
OUT-07-29:MLUPS: 577.4223280707176
OUT-07-30:MLUPS: 577.2255967221336
OUT-08-01:MLUPS: 577.277685388252
OUT-08-02:MLUPS: 577.0137874289354
OUT-08-03:MLUPS: 577.1333281553946
OUT-08-04:MLUPS: 577.305012020407
OUT-08-05:MLUPS: 577.2143988866626
OUT-08-06:MLUPS: 577.2409145495443
OUT-08-07:MLUPS: 577.2602819927345
OUT-08-08:MLUPS: 577.2823738293221
OUT-08-09:MLUPS: 577.3453199728248
OUT-08-11:MLUPS: 577.3161423260563
OUT-08-12:MLUPS: 577.1697775786824
OUT-08-13:MLUPS: 577.3049883393633
OUT-08-14:MLUPS: 576.9051978525331
OUT-08-15:MLUPS: 577.5331743016213
OUT-08-16:MLUPS: 577.5117505070573
OUT-08-18:MLUPS: 577.5930698237612
OUT-08-19:MLUPS: 577.3539885757353
OUT-08-20:MLUPS: 577.4190113959127
OUT-08-21:MLUPS: 577.300394253605
OUT-08-22:MLUPS: 577.4263792037783
OUT-08-23:MLUPS: 577.4087536357031
OUT-08-24:MLUPS: 577.1094728438082
OUT-08-25: File "/XLB/examples/performance/MLUPS3d.py", line 5, in <module>
OUT-08-26:MLUPS: 537.0164618489928
OUT-08-27:MLUPS: 536.9545448661609
OUT-08-28:MLUPS: 536.2887650464874
OUT-08-29:MLUPS: 536.7178471720636
OUT-08-30:MLUPS: 536.6978912984252
OUT-09-01:MLUPS: 536.7030899164106
OUT-09-04:MLUPS: 536.5339818238837
OUT-09-05:MLUPS: 536.6507808565617
OUT-09-06:MLUPS: 536.7144494518315
OUT-09-08:MLUPS: 536.7376612408998
OUT-09-09:MLUPS: 536.7798324141778
OUT-09-10:MLUPS: 536.726157440174
OUT-09-11:MLUPS: 536.7446210750584
OUT-09-12:MLUPS: 536.6707332269023
OUT-09-13:MLUPS: 536.6777936517823
OUT-09-14:MLUPS: 536.7581523280307
OUT-09-15:MLUPS: 536.6156273667873
OUT-09-16:MLUPS: 536.7320935035265
OUT-09-17:MLUPS: 536.7104991444398
OUT-09-18:MLUPS: 536.7492269469092
OUT-09-19:MLUPS: 536.6760131792959
OUT-09-20:MLUPS: 536.7361260076634
结果发现 8-24 是好的,但 8-26 是不好的。在 8-25,由于另一个问题导致无法获得结果。因此,我们需要调查 8-24 和 8-26 之间的每小时构建。之前有一个较小的减速,让我们在此示例中忽略它。那将只是那些日期之间的另一次每小时调查。
每小时调查#
这会在两个日期之间的每个小时签出 JAX 和 XLA,重建所有内容并运行测试。脚本结构有所不同。我们启动工作容器并保持它运行。然后,在其中,我们只触发增量 XLA 构建,除了第一次构建。因此,第一次迭代之后会快得多。
test_runner2.sh
# Execute this script inside the container:
# docker run -v $PWD:/dir --gpus=all ghcr.io/nvidia/jax:nightly-2023-08-24 /bin/bash
cd /opt/xla-source
git remote update
cd /opt/jax-source
git remote update
pip install jmp pyvista numpy matplotlib Rtree trimesh jmp termcolor orbax
cd /tmp
git clone https://github.com/Autodesk/XLB
cd XLB
for d in `seq -w 24 26`; do
for h in `seq -w 0 24`; do
echo $m $d $h
/bin/bash /dir/test2.sh Aug $d 2023 $h:00:00 &> OUT-08-${d}-$h
done
done
test2.sh
echo "param: $@"
cd /opt/xla-source
git checkout `git rev-list -1 --before="$*" origin/main`
git show -q
cd /opt/jax-source
git checkout `git rev-list -1 --before="$*" origin/main`
git show -q
rm /opt/jax-source/dist/jax*.whl
build-jax.sh # The script is in the nightly container
export PYTHONPATH=.
export CUDA_VISIBLE_DEVICES=0 # only 1 GPU is needed
python3 examples/performance/MLUPS3d.py 256 200
现在,你可以在新的输出文件上执行 grep 命令,以查看问题出现在哪些小时之间。
最终验证#
通过这些,你需要检查那些小时之间的 JAX 和 XLA 历史记录。可能有一些 commit 需要测试。如果你想使用更高级的方法,可以使用 git bisect。
可以改进吗?#
是的!如果是一个崩溃回归,能够进行二分查找会很有用。但这会更复杂。如果有人想贡献这样的说明,请提交一个 PR ;)
对于速度回归,二分查找可能会隐藏一些信息。我们不会轻易地看到这里有两个回归。