并行编程简介#
本教程介绍了 JAX 中单程序多数据(SPMD)代码的设备并行。SPMD 是一种并行技术,其中相同的计算(例如,神经网络的前向传递)可以在不同的设备(例如,多个 GPU 或 Google TPU)上并行运行不同的输入数据(例如,批次中的不同输入)。
本教程涵盖了三种并行计算模式
通过
jax.jit()
实现自动并行:编译器选择最佳计算策略(又名“编译器掌舵”)。使用
jax.jit()
和jax.lax.with_sharding_constraint()
实现半自动并行使用
jax.experimental.shard_map.shard_map()
进行完全手动并行控制:shard_map
启用每个设备的代码和显式通信集合
通过使用这些 SPMD 的思想,您可以将为单个设备编写的函数转换为可以在多个设备上并行运行的函数。
如果您在 Google Colab 笔记本中运行这些示例,请确保您的硬件加速器是最新的 Google TPU,方法是检查您的笔记本设置:**运行时** > **更改运行时类型** > **硬件加速器** > **TPU v2**(它提供八个设备供使用)。
import jax
jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
核心概念:数据分片#
以下所有分布式计算方法的核心概念是数据分片,它描述了数据如何在可用设备上布局。
JAX 如何理解数据在设备上的布局?JAX 的数据类型 jax.Array
不可变数组数据结构表示跨一个或多个设备的物理存储的数组,并有助于将并行性作为 JAX 的核心功能。jax.Array
对象的设计考虑了分布式数据和计算。每个 jax.Array
都有一个关联的 jax.sharding.Sharding
对象,它描述了每个全局设备需要哪个全局数据的分片。当您从头开始创建 jax.Array
时,还需要创建它的 Sharding
。
在最简单的情况下,数组在单个设备上分片,如下所示
import jax.numpy as jnp
arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()
{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}
arr.sharding
SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))
为了更直观地表示存储布局,jax.debug
模块提供了一些辅助函数来可视化数组的分片。例如,jax.debug.visualize_array_sharding()
显示了数组如何在单个设备的内存中存储
jax.debug.visualize_array_sharding(arr)
TPU 0
要创建具有非平凡分片的数组,您可以为数组定义一个 jax.sharding
规范,并将其传递给 jax.device_put()
。
这里,定义一个 NamedSharding
,它指定一个具有命名轴的设备 N 维网格,其中 jax.sharding.Mesh
允许精确的设备放置
from jax.sharding import PartitionSpec as P
mesh = jax.make_mesh((2, 4), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)
NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'))
将此 Sharding
对象传递给 jax.device_put()
,您可以获得一个分片的数组
arr_sharded = jax.device_put(arr, sharding)
print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
[[ 0. 1. 2. 3. 4. 5. 6. 7.]
[ 8. 9. 10. 11. 12. 13. 14. 15.]
[16. 17. 18. 19. 20. 21. 22. 23.]
[24. 25. 26. 27. 28. 29. 30. 31.]]
TPU 0 TPU 1 TPU 2 TPU 3 TPU 6 TPU 7 TPU 4 TPU 5
这里的设备编号不是按数字顺序排列的,因为网格反映了设备底层的环形拓扑。
1. 通过 jit
实现自动并行#
一旦你有了分片的数据,进行并行计算最简单的方法就是简单地将数据传递给 jax.jit()
编译的函数!在 JAX 中,您只需要指定希望如何对代码的输入和输出进行分区,编译器将计算出如何:1)对内部所有内容进行分区;2)编译设备间通信。
jit
背后的 XLA 编译器包含用于优化跨多个设备的计算的启发式方法。在最简单的情况下,这些启发式方法归结为计算跟随数据。
为了演示 JAX 中的自动并行化如何工作,下面是一个使用 jax.jit()
修饰的阶段输出函数的示例:这是一个简单的按元素操作的函数,其中每个分片的计算将在与该分片关联的设备上执行,并且输出以相同的方式分片
@jax.jit
def f_elementwise(x):
return 2 * jnp.sin(x) + 1
result = f_elementwise(arr_sharded)
print("shardings match:", result.sharding == arr_sharded.sharding)
shardings match: True
随着计算变得更加复杂,编译器会决定如何最好地传播数据的分片。
在这里,您沿着 x
的前导轴求和,并可视化结果值如何在多个设备上存储(使用 jax.debug.visualize_array_sharding()
)
@jax.jit
def f_contract(x):
return x.sum(axis=0)
result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
TPU 0,6 TPU 1,7 TPU 2,4 TPU 3,5
[48. 52. 56. 60. 64. 68. 72. 76.]
结果是部分复制的:也就是说,数组的前两个元素在设备 0
和 6
上复制,第二个在 1
和 7
上复制,依此类推。
2. 使用约束实现半自动分片#
如果您想对特定计算中使用的分片进行一些控制,JAX 提供了 with_sharding_constraint()
函数。您可以将 jax.lax.with_sharding_constraint()
(代替 jax.device_put()
)与 jax.jit()
一起使用,以便更好地控制编译器如何约束中间值和输出的分布。
例如,假设在上面的 f_contract
中,您希望输出不是部分复制的,而是在八个设备上完全分片
@jax.jit
def f_contract_2(x):
out = x.sum(axis=0)
sharding = jax.sharding.NamedSharding(mesh, P('x'))
return jax.lax.with_sharding_constraint(out, sharding)
result = f_contract_2(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
TPU 0 TPU 1 TPU 2 TPU 3 TPU 6 TPU 7 TPU 4 TPU 5
[48. 52. 56. 60. 64. 68. 72. 76.]
这为您提供了一个具有您想要的特定输出分片的函数。
3. 使用 shard_map
实现手动并行#
在上面探讨的自动并行方法中,您可以编写一个函数,就好像您正在对完整数据集进行操作一样,jit
会将该计算拆分到多个设备上。相比之下,使用 jax.experimental.shard_map.shard_map()
,您可以编写一个将处理单个数据分片的函数,而 shard_map
将构建完整的函数。
shard_map
的工作原理是将函数映射到特定的设备网格上(shard_map
映射到分片上)。在下面的示例中
和之前一样,
jax.sharding.Mesh
允许精确的设备放置,并为逻辑轴和物理轴名称提供轴名称参数。in_specs
参数确定分片大小。out_specs
参数标识如何将块重新组合在一起。
注意:如果需要,jax.experimental.shard_map.shard_map()
代码可以在 jax.jit()
内部工作。
from jax.experimental.shard_map import shard_map
mesh = jax.make_mesh((8,), ('x',))
f_elementwise_sharded = shard_map(
f_elementwise,
mesh=mesh,
in_specs=P('x'),
out_specs=P('x'))
arr = jnp.arange(32)
f_elementwise_sharded(arr)
Array([ 1. , 2.682942 , 2.818595 , 1.28224 , -0.513605 ,
-0.9178486 , 0.44116896, 2.3139732 , 2.9787164 , 1.824237 ,
-0.08804226, -0.99998045, -0.07314599, 1.8403342 , 2.9812148 ,
2.3005757 , 0.42419332, -0.92279506, -0.50197446, 1.2997544 ,
2.8258905 , 2.6733112 , 0.98229736, -0.69244075, -0.81115675,
0.7352965 , 2.525117 , 2.912752 , 1.5418116 , -0.32726777,
-0.97606325, 0.19192469], dtype=float32)
你编写的函数只“看到”数据的单个批次,你可以通过打印设备局部形状来检查这一点。
x = jnp.arange(32)
print(f"global shape: {x.shape=}")
def f(x):
print(f"device local shape: {x.shape=}")
return x * 2
y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
global shape: x.shape=(32,)
device local shape: x.shape=(4,)
因为你的每个函数只“看到”数据的设备局部部分,这意味着像聚合这样的函数需要额外考虑。
例如,下面是一个 shard_map
的 jax.numpy.sum()
的样子
def f(x):
return jnp.sum(x, keepdims=True)
shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
Array([ 6, 22, 38, 54, 70, 86, 102, 118], dtype=int32)
你的函数 f
在每个分片上单独操作,并且产生的求和反映了这一点。
如果你想跨分片求和,你需要使用像 jax.lax.psum()
这样的集合操作显式地请求它。
def f(x):
sum_in_shard = x.sum()
return jax.lax.psum(sum_in_shard, 'x')
shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
Array(496, dtype=int32)
因为输出不再具有分片维度,所以设置 out_specs=P()
(回想一下,out_specs
参数标识了在 shard_map
中如何将块重新组装在一起)。
比较三种方法#
在脑海中有了这些概念后,让我们比较一下用于简单神经网络层的三种方法。
首先像这样定义你的规范函数
@jax.jit
def layer(x, weights, bias):
return jax.nn.sigmoid(x @ weights + bias)
import numpy as np
rng = np.random.default_rng(0)
x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))
layer(x, weights, bias)
Array([0.02138912, 0.893112 , 0.59892005, 0.97742504], dtype=float32)
你可以使用 jax.jit()
并传递适当分片的数据,以分布式方式自动运行此函数。
如果你以相同的方式对 x
和 weights
的前导轴进行分片,那么矩阵乘法将自动并行发生。
mesh = jax.make_mesh((8,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
x_sharded = jax.device_put(x, sharding)
weights_sharded = jax.device_put(weights, sharding)
layer(x_sharded, weights_sharded, bias)
Array([0.02138912, 0.893112 , 0.59892005, 0.97742504], dtype=float32)
或者,你可以在函数中使用 jax.lax.with_sharding_constraint()
来自动分发未分片的输入。
@jax.jit
def layer_auto(x, weights, bias):
x = jax.lax.with_sharding_constraint(x, sharding)
weights = jax.lax.with_sharding_constraint(weights, sharding)
return layer(x, weights, bias)
layer_auto(x, weights, bias) # pass in unsharded inputs
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)
最后,你可以使用 shard_map
做同样的事情,使用 jax.lax.psum()
来指示矩阵乘法所需的跨分片集合操作。
from functools import partial
@jax.jit
@partial(shard_map, mesh=mesh,
in_specs=(P('x'), P('x', None), P(None)),
out_specs=P(None))
def layer_sharded(x, weights, bias):
return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)
layer_sharded(x, weights, bias)
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)
下一步#
本教程简要介绍了 JAX 中的分片和并行计算。
要深入了解每种 SPMD 方法,请查看这些文档