分布式数组和自动并行化#
本教程讨论了通过 jax.Array
进行并行化,这是 JAX v0.4.1 及更高版本中可用的统一数组对象模型。
from typing import Optional
import numpy as np
import jax
import jax.numpy as jnp
⚠️ 警告:该 notebook 需要 8 个设备才能运行。
if len(jax.local_devices()) < 8:
raise Exception("Notebook requires 8 devices to run")
介绍和一个快速示例#
通过阅读本教程 notebook,你将了解 jax.Array
,这是一种用于表示数组的统一数据类型,即使物理存储跨多个设备也是如此。你还将学习如何将 jax.Array
与 jax.jit
结合使用可以提供基于编译器的自动并行化。
在我们逐步思考之前,这里有一个快速示例。首先,我们将创建一个跨多个设备分片的 jax.Array
from jax.sharding import PartitionSpec as P, NamedSharding
# Create a Sharding object to distribute a value across devices:
mesh = jax.make_mesh((4, 2), ('x', 'y'))
# Create an array of random values:
x = jax.random.normal(jax.random.key(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:
y = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
接下来,我们将对其应用计算并可视化结果值如何在多个设备上存储
z = jnp.sin(y)
jax.debug.visualize_array_sharding(z)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
对 jnp.sin
应用的评估已在存储输入值(和输出值)的设备上自动并行化
# `x` is present on a single device
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready()
The slowest run took 8.96 times longer than the fastest. This could mean that an intermediate result is being cached.
25.2 ms ± 30.9 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
# `y` is sharded across 8 devices.
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()
2.4 ms ± 61.4 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)
现在让我们更详细地了解这些部分!
计算遵循数据分片并自动并行化#
有了分片的输入数据,编译器就可以为我们提供并行计算。特别是,用 jax.jit
修饰的函数可以在分片数组上运行,而无需将数据复制到单个设备上。相反,计算遵循分片:基于输入数据的分片,编译器为中间值和输出值确定分片,并并行化它们的评估,甚至在必要时插入通信操作。
例如,最简单的计算是逐元素的计算
mesh = jax.make_mesh((4, 2), ('a', 'b'))
x = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
print('input sharding:')
jax.debug.visualize_array_sharding(x)
y = jnp.sin(x)
print('output sharding:')
jax.debug.visualize_array_sharding(y)
input sharding:
output sharding:
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
这里对于逐元素操作 jnp.sin
,编译器选择的输出分片与输入相同。此外,编译器自动并行化了计算,因此每个设备都并行地从其输入分片计算其输出分片。
换句话说,即使我们编写 jnp.sin
计算时假设由一台机器执行,编译器也会为我们拆分计算并在多个设备上执行。
我们也可以对不仅仅是逐元素操作执行相同的操作。考虑一个带有分片输入的矩阵乘法
y = jax.device_put(x, NamedSharding(mesh, P('a', None)))
z = jax.device_put(x, NamedSharding(mesh, P(None, 'b')))
print('lhs sharding:')
jax.debug.visualize_array_sharding(y)
print('rhs sharding:')
jax.debug.visualize_array_sharding(z)
w = jnp.dot(y, z)
print('out sharding:')
jax.debug.visualize_array_sharding(w)
lhs sharding:
rhs sharding:
out sharding:
┌───────────────────────┐ │ TPU 0,1 │ ├───────────────────────┤ │ TPU 2,3 │ ├───────────────────────┤ │ TPU 6,7 │ ├───────────────────────┤ │ TPU 4,5 │ └───────────────────────┘
┌───────────┬───────────┐ │ │ │ │ │ │ │ │ │ │ │ │ │TPU 0,2,4,6│TPU 1,3,5,7│ │ │ │ │ │ │ │ │ │ │ │ │ └───────────┴───────────┘
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
这里,编译器选择了输出分片,以便它可以最大限度地并行化计算:无需通信,每个设备都已经拥有计算其输出分片所需的输入分片。
我们如何确定它实际上是并行运行的?我们可以做一个简单的计时实验
x_single = jax.device_put(x, jax.devices()[0])
jax.debug.visualize_array_sharding(x_single)
┌───────────────────────┐
│ │
│ │
│ │
│ │
│ TPU 0 │
│ │
│ │
│ │
│ │
└───────────────────────┘
np.allclose(jnp.dot(x_single, x_single),
jnp.dot(y, z))
True
%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready()
49.7 ms ± 349 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)
%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready()
7.47 ms ± 44.8 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)
甚至复制分片的 Array
也会产生与输入分片相同的结果
w_copy = jnp.copy(w)
jax.debug.visualize_array_sharding(w_copy)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
因此,计算遵循数据放置:当我们使用 jax.device_put
显式地对数据进行分片,并将函数应用于该数据时,编译器会尝试并行化计算并确定输出分片。这种分片数据的策略是对 JAX 的遵循显式设备放置的策略 的概括。
当显式分片不一致时,JAX 会报错#
但是,如果对计算的两个参数显式地放置在不同的设备集上,或者具有不兼容的设备顺序会发生什么?在这些模糊的情况下,会引发错误
import textwrap
from termcolor import colored
def print_exception(e):
name = colored(f'{type(e).__name__}', 'red', force_color=True)
print(textwrap.fill(f'{name}: {str(e)}'))
sharding1 = NamedSharding(Mesh(jax.devices()[:4], 'x'), P('x'))
sharding2 = NamedSharding(Mesh(jax.devices()[4:], 'x'), P('x'))
y = jax.device_put(x, sharding1)
z = jax.device_put(x, sharding2)
try: y + z
except ValueError as e: print_exception(e)
ValueError: Received incompatible devices for jitted
computation. Got argument x1 of jax.numpy.add with shape int32[24] and
device ids [0, 1, 2, 3] on platform TPU and argument x2 of
jax.numpy.add with shape int32[24] and device ids [4, 5, 6, 7] on
platform TPU
devices = jax.devices()
permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]
sharding1 = NamedSharding(Mesh(devices, 'x'), P('x'))
sharding2 = NamedSharding(Mesh(permuted_devices, 'x'), P('x'))
y = jax.device_put(x, sharding1)
z = jax.device_put(x, sharding2)
try: y + z
except ValueError as e: print_exception(e)
ValueError: Received incompatible devices for jitted
computation. Got argument x1 of jax.numpy.add with shape int32[24] and
device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform TPU and argument x2 of
jax.numpy.add with shape int32[24] and device ids [0, 1, 2, 3, 6, 7,
4, 5] on platform TPU
我们说已使用 jax.device_put
显式放置或分片的数组 committed 到其设备,因此不会自动移动。有关更多信息,请参见设备放置常见问题解答。
当数组 没有 使用 jax.device_put
显式放置或分片时,它们会 uncommitted 放置在默认设备上。与 committed 数组不同,uncommitted 数组可以自动移动和重新分片:也就是说,即使其他参数显式放置在不同的设备上,uncommitted 数组也可以作为计算的参数。
例如,jnp.zeros
,jnp.arange
和 jnp.array
的输出是 uncommitted 的
y = jax.device_put(x, sharding1)
y + jnp.ones_like(y)
y + jnp.arange(y.size).reshape(y.shape)
print('no error!')
no error!
约束 jit
代码中中间体的分片#
虽然编译器会尝试决定如何对函数的中间值和输出进行分片,但我们也可以使用 jax.lax.with_sharding_constraint
为其提供提示。使用 jax.lax.with_sharding_constraint
非常类似于 jax.device_put
,只不过我们是在 staged-out(即 jit
修饰)函数内部使用它
mesh = jax.make_mesh((4, 2), ('x', 'y'))
x = jax.random.normal(jax.random.key(0), (8192, 8192))
x = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
@jax.jit
def f(x):
x = x + 1
y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('y', 'x')))
return y
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
┌───────┬───────┬───────┬───────┐ │ │ │ │ │ │ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │ │ │ │ │ │ │ │ │ │ │ ├───────┼───────┼───────┼───────┤ │ │ │ │ │ │ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │ │ │ │ │ │ │ │ │ │ │ └───────┴───────┴───────┴───────┘
@jax.jit
def f(x):
x = x + 1
y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
return y
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
┌───────────────────────┐ │ │ │ │ │ │ │ │ │ TPU 0,1,2,3,4,5,6,7 │ │ │ │ │ │ │ │ │ └───────────────────────┘
通过添加 with_sharding_constraint
,我们约束了输出的分片。除了尊重特定中间值上的注释外,编译器还将使用注释来决定其他值的分片。
通常,最好根据最终如何使用值来注释计算的输出。
示例:神经网络#
⚠️ 警告:以下内容旨在简单演示使用 jax.Array
进行自动分片传播,但这可能不反映实际示例的最佳实践。 例如,实际示例可能需要更多地使用 with_sharding_constraint
。
我们可以使用 jax.device_put
和 jax.jit
的计算跟随分片功能来并行化神经网络中的计算。以下是一些基于这个基本神经网络的简单示例:
import jax
import jax.numpy as jnp
def predict(params, inputs):
for W, b in params:
outputs = jnp.dot(inputs, W) + b
inputs = jnp.maximum(outputs, 0)
return outputs
def loss(params, batch):
inputs, targets = batch
predictions = predict(params, inputs)
return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
loss_jit = jax.jit(loss)
gradfun = jax.jit(jax.grad(loss))
def init_layer(key, n_in, n_out):
k1, k2 = jax.random.split(key)
W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
b = jax.random.normal(k2, (n_out,))
return W, b
def init_model(key, layer_sizes, batch_size):
key, *keys = jax.random.split(key, len(layer_sizes))
params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
key, *keys = jax.random.split(key, 3)
inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))
return params, (inputs, targets)
layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192
params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)
8 路批数据并行#
mesh = jax.make_mesh((8,), ('batch',))
sharding = NamedSharding(mesh, P('batch'))
replicated_sharding = NamedSharding(mesh, P())
batch = jax.device_put(batch, sharding)
params = jax.device_put(params, replicated_sharding)
loss_jit(params, batch)
Array(23.469475, dtype=float32)
step_size = 1e-5
for _ in range(30):
grads = gradfun(params, batch)
params = [(W - step_size * dW, b - step_size * db)
for (W, b), (dW, db) in zip(params, grads)]
print(loss_jit(params, batch))
10.760109
%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()
53.8 ms ± 1.14 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
batch_single = jax.device_put(batch, jax.devices()[0])
params_single = jax.device_put(params, jax.devices()[0])
%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()
351 ms ± 81.2 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
4 路批数据并行和 2 路模型张量并行#
mesh = jax.make_mesh((4, 2), ('batch', 'model'))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch', None)))
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1])
┌───────┐ │TPU 0,1│ ├───────┤ │TPU 2,3│ ├───────┤ │TPU 6,7│ ├───────┤ │TPU 4,5│ └───────┘
┌───────┐ │TPU 0,1│ ├───────┤ │TPU 2,3│ ├───────┤ │TPU 6,7│ ├───────┤ │TPU 4,5│ └───────┘
replicated_sharding = NamedSharding(mesh, P())
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
W1 = jax.device_put(W1, replicated_sharding)
b1 = jax.device_put(b1, replicated_sharding)
W2 = jax.device_put(W2, NamedSharding(mesh, P(None, 'model')))
b2 = jax.device_put(b2, NamedSharding(mesh, P('model')))
W3 = jax.device_put(W3, NamedSharding(mesh, P('model', None)))
b3 = jax.device_put(b3, replicated_sharding)
W4 = jax.device_put(W4, replicated_sharding)
b4 = jax.device_put(b4, replicated_sharding)
params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)
jax.debug.visualize_array_sharding(W2)
┌───────────┬───────────┐ │ │ │ │ │ │ │ │ │ │ │ │ │TPU 0,2,4,6│TPU 1,3,5,7│ │ │ │ │ │ │ │ │ │ │ │ │ └───────────┴───────────┘
jax.debug.visualize_array_sharding(W3)
┌───────────────────────┐ │ │ │ TPU 0,2,4,6 │ │ │ │ │ ├───────────────────────┤ │ │ │ TPU 1,3,5,7 │ │ │ │ │ └───────────────────────┘
print(loss_jit(params, batch))
10.760109
step_size = 1e-5
for _ in range(30):
grads = gradfun(params, batch)
params = [(W - step_size * dW, b - step_size * db)
for (W, b), (dW, db) in zip(params, grads)]
print(loss_jit(params, batch))
10.752513
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
jax.debug.visualize_array_sharding(W2)
jax.debug.visualize_array_sharding(W3)
┌───────────┬───────────┐ │ │ │ │ │ │ │ │ │ │ │ │ │TPU 0,2,4,6│TPU 1,3,5,7│ │ │ │ │ │ │ │ │ │ │ │ │ └───────────┴───────────┘
┌───────────────────────┐ │ │ │ TPU 0,2,4,6 │ │ │ │ │ ├───────────────────────┤ │ │ │ TPU 1,3,5,7 │ │ │ │ │ └───────────────────────┘
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()
51.4 ms ± 454 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)