并行编程简介#

本教程介绍了 JAX 中单程序多数据 (SPMD) 代码的设备并行性。SPMD 是一种并行技术,其中相同的计算(例如,神经网络的前向传递)可以在不同的输入数据(例如,批次中的不同输入)上并行地在不同的设备(例如,多个 GPU 或 Google TPU)上运行。

本教程涵盖了三种并行计算模式

使用这些 SPMD 思想,您可以将为单个设备编写的函数转换为可在多个设备上并行运行的函数。

如果您在 Google Colab 笔记本中运行这些示例,请确保您的硬件加速器是最新的 Google TPU。为此,请检查笔记本设置:运行时 > 更改运行时类型 > 硬件加速器 > TPU v2(它提供八个设备供您使用)。

import jax
jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

关键概念:数据分片#

所有以下分布式计算方法的关键都是数据分片的概念,它描述了数据如何在可用设备上布局。

JAX 如何理解数据如何在设备之间布局?JAX 的数据类型,即 jax.Array 不可变数组数据结构,表示跨一个或多个设备具有物理存储的数组,并有助于使并行成为 JAX 的核心功能。 jax.Array 对象的设计考虑了分布式数据和计算。每个 jax.Array 都与一个关联的 jax.sharding.Sharding 对象相关联,该对象描述了每个全局设备所需的全局数据的哪个分片。当您从头开始创建 jax.Array 时,您还需要创建它的 Sharding

在最简单的情况下,数组将在单个设备上进行分片,如下所示

import jax.numpy as jnp
arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()
{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}
arr.sharding
SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))

为了更直观地表示存储布局,jax.debug 模块提供了一些辅助方法来可视化数组的分片。例如,jax.debug.visualize_array_sharding() 显示了数组如何在单个设备的内存中存储

jax.debug.visualize_array_sharding(arr)
                                                  
                                                  
                                                  
                                                  
                                                  
                      TPU 0                       
                                                  
                                                  
                                                  
                                                  
                                                  

要创建具有非平凡分片的数组,您可以为数组定义 jax.sharding 规范,并将此规范传递给 jax.device_put().

在此,定义 NamedSharding,它指定具有命名轴的设备的 N 维网格,其中 jax.sharding.Mesh 允许精确的设备放置

from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((2, 4), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)
NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'))

将此 Sharding 对象传递给 jax.device_put(),您可以获得一个分片数组

arr_sharded = jax.device_put(arr, sharding)

print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
[[ 0.  1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21. 22. 23.]
 [24. 25. 26. 27. 28. 29. 30. 31.]]
                                                
                                                
   TPU 0       TPU 1       TPU 2       TPU 3    
                                                
                                                
                                                
                                                
                                                
   TPU 6       TPU 7       TPU 4       TPU 5    
                                                
                                                
                                                

此处的设备编号不是按数字顺序排列的,因为网格反映了设备的底层环形拓扑结构。

1. 通过 jit 实现自动并行#

获得分片数据后,执行并行计算的最简单方法是将数据传递给 jax.jit() 编译的函数!在 JAX 中,您只需要指定希望如何对代码的输入和输出进行分区,编译器将弄清楚如何:1) 对内部的所有内容进行分区;以及 2) 编译设备间通信。

jit 背后的 XLA 编译器包含用于优化跨多个设备的计算的启发式方法。在最简单的情况下,这些启发式方法归结为计算遵循数据

为了演示 JAX 中的自动并行化工作原理,以下示例使用 jax.jit() 装饰的阶段化函数:它是一个简单的逐元素函数,其中每个分片的计算将在与该分片关联的设备上执行,并且输出将以相同的方式进行分片

@jax.jit
def f_elementwise(x):
  return 2 * jnp.sin(x) + 1

result = f_elementwise(arr_sharded)

print("shardings match:", result.sharding == arr_sharded.sharding)
shardings match: True

随着计算变得更加复杂,编译器将决定如何最佳地传播数据的分片。

在此,您沿 x 的前导轴求和,并可视化结果值如何在多个设备之间存储(使用 jax.debug.visualize_array_sharding()

@jax.jit
def f_contract(x):
  return x.sum(axis=0)

result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
 TPU 0,6  TPU 1,7  TPU 2,4  TPU 3,5 
                                    
[48. 52. 56. 60. 64. 68. 72. 76.]

结果被部分复制:也就是说,数组的前两个元素在设备 06 上复制,第二个在 17 上复制,依此类推。

2. 带约束的半自动分片#

如果您希望对特定计算中使用的分片进行一些控制,JAX 提供了 with_sharding_constraint() 函数。您可以使用 jax.lax.with_sharding_constraint()(代替 jax.device_put())以及 jax.jit() 来更好地控制编译器如何约束中间值和输出的分布。

例如,假设在上面的 f_contract 中,您更希望输出不被部分复制,而是完全跨八个设备进行分片

@jax.jit
def f_contract_2(x):
  out = x.sum(axis=0)
  mesh = jax.make_mesh((8,), ('x',))
  sharding = jax.sharding.NamedSharding(mesh, P('x'))
  return jax.lax.with_sharding_constraint(out, sharding)

result = f_contract_2(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
  TPU 0    TPU 1    TPU 2    TPU 3    TPU 6    TPU 7    TPU 4    TPU 5  
                                                                        
[48. 52. 56. 60. 64. 68. 72. 76.]

这为您提供了具有您想要的特定输出分片的函数。

3. 使用 shard_map 的手动并行#

在上面探讨的自动并行方法中,您可以像操作完整数据集一样编写函数,并且 jit 将在多个设备上拆分该计算。相比之下,使用 jax.experimental.shard_map.shard_map(),您将编写处理单个数据分片的函数,并且 shard_map 将构建完整的函数。

shard_map 通过将函数映射到特定设备网格shard_map 映射到分片)来工作。在以下示例中

  • 与之前一样,jax.sharding.Mesh 允许精确的设备放置,其中逻辑和物理轴名的轴名参数。

  • in_specs 参数决定了分片大小。 out_specs 参数标识了块如何重新组合在一起。

注意: 如果需要,jax.experimental.shard_map.shard_map() 代码可以在 jax.jit() 内部工作。

from jax.experimental.shard_map import shard_map
mesh = jax.make_mesh((8,), ('x',))

f_elementwise_sharded = shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))

arr = jnp.arange(32)
f_elementwise_sharded(arr)
Array([ 1.        ,  2.682942  ,  2.818595  ,  1.28224   , -0.513605  ,
       -0.9178486 ,  0.44116896,  2.3139732 ,  2.9787164 ,  1.824237  ,
       -0.08804226, -0.99998045, -0.07314599,  1.8403342 ,  2.9812148 ,
        2.3005757 ,  0.42419332, -0.92279506, -0.50197446,  1.2997544 ,
        2.8258905 ,  2.6733112 ,  0.98229736, -0.69244075, -0.81115675,
        0.7352965 ,  2.525117  ,  2.912752  ,  1.5418116 , -0.32726777,
       -0.97606325,  0.19192469], dtype=float32)

您编写的函数只“看到”一批数据,您可以通过打印设备本地形状来检查这一点

x = jnp.arange(32)
print(f"global shape: {x.shape=}")

def f(x):
  print(f"device local shape: {x.shape=}")
  return x * 2

y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
global shape: x.shape=(32,)
device local shape: x.shape=(4,)

由于您的每个函数只“看到”数据的设备本地部分,这意味着类似聚合的函数需要一些额外的考虑。

例如,以下是 shard_mapjax.numpy.sum() 的示例

def f(x):
  return jnp.sum(x, keepdims=True)

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
Array([  6,  22,  38,  54,  70,  86, 102, 118], dtype=int32)

您的函数 f 在每个分片上独立运行,并且产生的总和反映了这一点。

如果您想跨分片求和,您需要使用像 jax.lax.psum() 这样的集体操作来显式请求它

def f(x):
  sum_in_shard = x.sum()
  return jax.lax.psum(sum_in_shard, 'x')

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
Array(496, dtype=int32)

由于输出不再具有分片维度,因此设置 out_specs=P()(回想一下,out_specs 参数标识了块如何在 shard_map 中重新组合在一起)。

比较三种方法#

牢记这些概念,让我们比较三种用于简单神经网络层的解决方法。

从定义您的规范函数开始,如下所示

@jax.jit
def layer(x, weights, bias):
  return jax.nn.sigmoid(x @ weights + bias)
import numpy as np
rng = np.random.default_rng(0)

x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))

layer(x, weights, bias)
Array([0.02138912, 0.893112  , 0.59892005, 0.97742504], dtype=float32)

您可以使用 jax.jit() 并传递适当的分片数据以自动以分布式方式运行此函数。

如果您以相同的方式对 xweights 的前导轴进行分片,则矩阵乘法将自动并行发生

mesh = jax.make_mesh((8,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, P('x'))

x_sharded = jax.device_put(x, sharding)
weights_sharded = jax.device_put(weights, sharding)

layer(x_sharded, weights_sharded, bias)
Array([0.02138912, 0.893112  , 0.59892005, 0.97742504], dtype=float32)

或者,您可以在函数中使用 jax.lax.with_sharding_constraint() 来自动分发未分片输入

@jax.jit
def layer_auto(x, weights, bias):
  x = jax.lax.with_sharding_constraint(x, sharding)
  weights = jax.lax.with_sharding_constraint(weights, sharding)
  return layer(x, weights, bias)

layer_auto(x, weights, bias)  # pass in unsharded inputs
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)

最后,您可以对 shard_map 做同样的事情,使用 jax.lax.psum() 来指示矩阵乘法所需的跨分片集体操作。

from functools import partial

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('x'), P('x', None), P(None)),
         out_specs=P(None))
def layer_sharded(x, weights, bias):
  return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)

layer_sharded(x, weights, bias)
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)

下一步#

本教程简要介绍了 JAX 中的分片和并行计算。

要深入了解每种 SPMD 方法,请查看以下文档