Asynchronous dispatch

JAX uses asynchronous dispatch to hide Python overheads. Consider the following program:

>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jax import random
>>> x = random.uniform(random.key(0), (1000, 1000))
>>> # Printing the result (i.e. evaluating `repr(result)` or `str(result)`)
>>> # will block until the value is ready.
>>> jnp.dot(x, x) + 3.  
Array([[258.01971436, 249.64862061, 257.13372803, ...,
        236.67948914, 250.68939209, 241.36853027],
        [265.65979004, 256.28912354, 262.18252563, ...,
        242.03181458, 256.16757202, 252.44122314],
        [262.38916016, 255.72747803, 261.23059082, ...,
        240.83563232, 255.41094971, 249.62471008],
        ...,
        [259.15814209, 253.09197998, 257.72174072, ...,
        242.23876953, 250.72680664, 247.16642761],
        [271.22662354, 261.91204834, 265.33398438, ...,
        248.26651001, 262.05389404, 261.33700562],
        [257.16134644, 254.7543335, 259.08300781, ..., 241.59848022,
        248.62597656, 243.22348022]], dtype=float32)

When an operation such as jnp.dot(x, x) is executed, JAX does not wait for the operation to complete before returning control to the Python program. Instead, JAX returns a jax.Array value, which is a future: a value that will be produced on an accelerator device at some point in the future, but is not necessarily available immediately. We can inspect the shape or type of a jax.Array without waiting for the computation that produced it to complete, and we can even pass it on to another JAX computation, as we do with the addition operation here. Only when we actually inspect the value of the array from the host, for example by printing it or by converting it into a plain numpy.ndarray, does JAX force the Python code to wait for the computation to complete.
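
As a small illustrative sketch (the names y, z, and result are chosen here only for illustration), continuing with the x defined above:

>>> y = jnp.dot(x, x)       # returns immediately; `y` is a future
>>> y.shape                 # metadata is available without waiting for the computation
(1000, 1000)
>>> y.dtype                 # likewise for the dtype
dtype('float32')
>>> z = y + 3.              # chaining another JAX operation also does not block
>>> result = np.asarray(z)  # converting to a NumPy array forces Python to wait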

Asynchronous dispatch is useful because it allows Python code to "run ahead of" the accelerator device, keeping Python code off the critical path. Provided the Python code enqueues work on the device faster than it can be executed, and provided the Python code does not actually need to inspect the output of a computation on the host, a Python program can enqueue arbitrary amounts of work and avoid having the accelerator wait.
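
As a rough sketch of this pattern (the loop body, the scaling factor, and the ten iterations are arbitrary choices for illustration), a program can dispatch a sequence of dependent steps back to back and synchronize only once at the end:

>>> y = x
>>> for _ in range(10):            # each iteration enqueues work; Python does not wait
...     y = jnp.dot(y, y) / 1000.  # scale down to keep values small across iterations
>>> y = y.block_until_ready()      # a single synchronization point at the end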

Asynchronous dispatch has a slightly surprising consequence for microbenchmarks.

>>> %time jnp.dot(x, x)  
CPU times: user 267 µs, sys: 93 µs, total: 360 µs
Wall time: 269 µs
Array([[255.01972961, 246.64862061, 254.13371277, ...,
        233.67948914, 247.68939209, 238.36853027],
        [262.65979004, 253.28910828, 259.18252563, ...,
        239.03181458, 253.16757202, 249.44122314],
        [259.38916016, 252.72747803, 258.23059082, ...,
        237.83563232, 252.41094971, 246.62471008],
        ...,
        [256.15814209, 250.09197998, 254.72172546, ...,
        239.23876953, 247.72680664, 244.16642761],
        [268.22662354, 258.91204834, 262.33398438, ...,
        245.26651001, 259.05389404, 258.33700562],
        [254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
        245.62597656, 240.22348022]], dtype=float32)

269µs is a surprisingly small time for a 1000x1000 matrix multiplication on CPU! However, it turns out that asynchronous dispatch is misleading us: we are not timing the execution of the matrix multiplication, only the time to dispatch the work. To measure the true cost of the operation we must either read the value on the host (e.g., convert it to a plain host-side numpy array), or use the block_until_ready() method on a jax.Array value to wait for the computation that produced it to complete.

>>> %time np.asarray(jnp.dot(x, x))  
CPU times: user 61.1 ms, sys: 0 ns, total: 61.1 ms
Wall time: 8.09 ms
array([[255.01973, 246.64862, 254.13371, ..., 233.67949, 247.68939,
        238.36853],
       [262.6598 , 253.28911, 259.18253, ..., 239.03181, 253.16757,
        249.44122],
       [259.38916, 252.72748, 258.2306 , ..., 237.83563, 252.41095,
        246.62471],
       ...,
       [256.15814, 250.09198, 254.72173, ..., 239.23877, 247.7268 ,
        244.16643],
       [268.22662, 258.91205, 262.33398, ..., 245.26651, 259.0539 ,
        258.337  ],
       [254.16135, 251.75433, 256.083  , ..., 238.59848, 245.62598,
        240.22348]], dtype=float32)
>>> %time jnp.dot(x, x).block_until_ready()  
CPU times: user 50.3 ms, sys: 928 µs, total: 51.2 ms
Wall time: 4.92 ms
Array([[255.01972961, 246.64862061, 254.13371277, ...,
        233.67948914, 247.68939209, 238.36853027],
        [262.65979004, 253.28910828, 259.18252563, ...,
        239.03181458, 253.16757202, 249.44122314],
        [259.38916016, 252.72747803, 258.23059082, ...,
        237.83563232, 252.41094971, 246.62471008],
        ...,
        [256.15814209, 250.09197998, 254.72172546, ...,
        239.23876953, 247.72680664, 244.16642761],
        [268.22662354, 258.91204834, 262.33398438, ...,
        245.26651001, 259.05389404, 258.33700562],
        [254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
        245.62597656, 240.22348022]], dtype=float32)

Blocking without transferring the result back to Python is usually faster, and is often the best choice when writing microbenchmarks of computation times.
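
For example, a repeatable version of this microbenchmark might combine the standard library's timeit with block_until_ready() (a minimal sketch; the repetition count n is arbitrary):

>>> import timeit
>>> n = 100                        # arbitrary number of repetitions
>>> total = timeit.timeit(lambda: jnp.dot(x, x).block_until_ready(), number=n)
>>> per_call_ms = total / n * 1e3  # average wall time per multiplication, in milliseconds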