异步调度#
JAX 使用异步调度来隐藏 Python 开销。考虑以下程序
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jax import random
>>> x = random.uniform(random.key(0), (1000, 1000))
>>> # Printing the result (i.e. evaluating `repr(result)` or `str(result)`)
>>> # will block until the value is ready.
>>> jnp.dot(x, x) + 3.
Array([[258.01971436, 249.64862061, 257.13372803, ...,
236.67948914, 250.68939209, 241.36853027],
[265.65979004, 256.28912354, 262.18252563, ...,
242.03181458, 256.16757202, 252.44122314],
[262.38916016, 255.72747803, 261.23059082, ...,
240.83563232, 255.41094971, 249.62471008],
...,
[259.15814209, 253.09197998, 257.72174072, ...,
242.23876953, 250.72680664, 247.16642761],
[271.22662354, 261.91204834, 265.33398438, ...,
248.26651001, 262.05389404, 261.33700562],
[257.16134644, 254.7543335, 259.08300781, ..., 241.59848022,
248.62597656, 243.22348022]], dtype=float32)
当执行诸如 jnp.dot(x, x)
的操作时,JAX 不会等待操作完成才将控制权返回给 Python 程序。相反,JAX 返回一个 jax.Array
值,这是一个 future,即一个将在未来在加速器设备上产生但不必立即可用的值。我们可以在不等待产生它的计算完成的情况下检查 jax.Array
的形状或类型,我们甚至可以将其传递给另一个 JAX 计算,就像我们在这里对加法运算所做的那样。只有当我们实际从主机检查数组的值时,例如通过打印它或将其转换为普通的 numpy.ndarray
时,JAX 才会强制 Python 代码等待计算完成。
异步调度非常有用,因为它允许 Python 代码“超前于”加速器设备运行,使 Python 代码不处于关键路径上。只要 Python 代码在设备上的工作入队速度快于其执行速度,并且只要 Python 代码实际上不需要检查主机上计算的输出,那么 Python 程序就可以入队任意数量的工作,并避免加速器等待。
异步调度对微基准测试产生了一个稍微令人惊讶的后果。
>>> %time jnp.dot(x, x)
CPU times: user 267 µs, sys: 93 µs, total: 360 µs
Wall time: 269 µs
Array([[255.01972961, 246.64862061, 254.13371277, ...,
233.67948914, 247.68939209, 238.36853027],
[262.65979004, 253.28910828, 259.18252563, ...,
239.03181458, 253.16757202, 249.44122314],
[259.38916016, 252.72747803, 258.23059082, ...,
237.83563232, 252.41094971, 246.62471008],
...,
[256.15814209, 250.09197998, 254.72172546, ...,
239.23876953, 247.72680664, 244.16642761],
[268.22662354, 258.91204834, 262.33398438, ...,
245.26651001, 259.05389404, 258.33700562],
[254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
245.62597656, 240.22348022]], dtype=float32)
对于 CPU 上的 1000x1000 矩阵乘法来说,269 微秒是一个非常小的运行时间!然而,事实证明,异步调度正在误导我们,我们测量的不是矩阵乘法的执行时间,而只是调度工作的时间。要测量操作的真实成本,我们必须读取主机上的值(例如,将其转换为普通的宿主机 numpy 数组),或者使用 block_until_ready()
方法在 jax.Array
值上等待产生它的计算完成。
>>> %time np.asarray(jnp.dot(x, x))
CPU times: user 61.1 ms, sys: 0 ns, total: 61.1 ms
Wall time: 8.09 ms
Out[16]:
array([[255.01973, 246.64862, 254.13371, ..., 233.67949, 247.68939,
238.36853],
[262.6598 , 253.28911, 259.18253, ..., 239.03181, 253.16757,
249.44122],
[259.38916, 252.72748, 258.2306 , ..., 237.83563, 252.41095,
246.62471],
...,
[256.15814, 250.09198, 254.72173, ..., 239.23877, 247.7268 ,
244.16643],
[268.22662, 258.91205, 262.33398, ..., 245.26651, 259.0539 ,
258.337 ],
[254.16135, 251.75433, 256.083 , ..., 238.59848, 245.62598,
240.22348]], dtype=float32)
>>> %time jnp.dot(x, x).block_until_ready()
CPU times: user 50.3 ms, sys: 928 µs, total: 51.2 ms
Wall time: 4.92 ms
Array([[255.01972961, 246.64862061, 254.13371277, ...,
233.67948914, 247.68939209, 238.36853027],
[262.65979004, 253.28910828, 259.18252563, ...,
239.03181458, 253.16757202, 249.44122314],
[259.38916016, 252.72747803, 258.23059082, ...,
237.83563232, 252.41094971, 246.62471008],
...,
[256.15814209, 250.09197998, 254.72172546, ...,
239.23876953, 247.72680664, 244.16642761],
[268.22662354, 258.91204834, 262.33398438, ...,
245.26651001, 259.05389404, 258.33700562],
[254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
245.62597656, 240.22348022]], dtype=float32)
在不将结果传输回 Python 的情况下进行阻塞通常更快,并且在编写计算时间的微基准测试时通常是最佳选择。