并行编程入门#

本教程是 JAX 中单程序多数据 (SPMD) 代码的设备并行入门。SPMD 是一种并行技术,其中相同的计算(例如神经网络的前向传递)可以在不同的设备(例如多个 GPU 或 Google TPU)上并行运行不同的输入数据(例如,批处理中的不同输入)。

本教程涵盖三种并行计算模式

使用这些 SPMD 的思想,您可以将为单个设备编写的函数转换为可以在多个设备上并行运行的函数。

如果您在 Google Colab 笔记本中运行这些示例,请通过检查笔记本设置来确保您的硬件加速器是最新的 Google TPU:运行时 > 更改运行时类型 > 硬件加速器 > TPU v2 (它提供八个设备来使用)。

import jax
jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

关键概念:数据分片#

下面所有分布式计算方法的关键是数据分片的概念,它描述了数据如何在可用设备上布局。

JAX 如何理解数据在设备上的布局?JAX 的数据类型 jax.Array 不可变数组数据结构,表示跨一个或多个设备的物理存储的数组,并有助于使并行性成为 JAX 的核心功能。jax.Array 对象的设计考虑了分布式数据和计算。每个 jax.Array 都有一个关联的 jax.sharding.Sharding 对象,它描述了每个全局设备需要哪个全局数据的分片。当您从头开始创建一个 jax.Array 时,您还需要创建它的 Sharding

在最简单的情况下,数组在单个设备上进行分片,如下所示

import jax.numpy as jnp
arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()
{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}
arr.sharding
SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))

为了更直观地表示存储布局,jax.debug 模块提供了一些辅助工具来可视化数组的分片。例如,jax.debug.visualize_array_sharding() 显示了数组如何存储在单个设备的内存中

jax.debug.visualize_array_sharding(arr)
                                                  
                                                  
                                                  
                                                  
                                                  
                      TPU 0                       
                                                  
                                                  
                                                  
                                                  
                                                  

要创建一个具有非平凡分片的数组,您可以为该数组定义一个 jax.sharding 规范,并将其传递给 jax.device_put()

在这里,定义一个 NamedSharding,它指定一个具有命名轴的 N 维设备网格,其中 jax.sharding.Mesh 允许精确的设备放置

from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((2, 4), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)
NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'))

将此 Sharding 对象传递给 jax.device_put(),您可以获得一个分片的数组

arr_sharded = jax.device_put(arr, sharding)

print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
[[ 0.  1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21. 22. 23.]
 [24. 25. 26. 27. 28. 29. 30. 31.]]
                                                
                                                
   TPU 0       TPU 1       TPU 2       TPU 3    
                                                
                                                
                                                
                                                
                                                
   TPU 6       TPU 7       TPU 4       TPU 5    
                                                
                                                
                                                

这里的设备编号不是按数字顺序排列的,因为网格反映了设备底层的环形拓扑结构。

1. 通过 jit 实现自动并行#

一旦您有了分片数据,执行并行计算的最简单方法就是简单地将数据传递给 jax.jit() 编译的函数!在 JAX 中,您只需要指定您希望如何对代码的输入和输出进行分区,编译器将找出如何:1) 对内部的所有内容进行分区;2) 编译设备间的通信。

jit 背后的 XLA 编译器包含用于优化跨多个设备的计算的启发式方法。在最简单的情况下,这些启发式方法可以归结为计算跟随数据

为了演示 JAX 中自动并行化是如何工作的,下面是一个示例,它使用 jax.jit() 修饰的暂存函数:它是一个简单的逐元素函数,其中每个分片的计算将在与该分片关联的设备上执行,并且输出以相同的方式进行分片

@jax.jit
def f_elementwise(x):
  return 2 * jnp.sin(x) + 1

result = f_elementwise(arr_sharded)

print("shardings match:", result.sharding == arr_sharded.sharding)
shardings match: True

随着计算变得越来越复杂,编译器会决定如何最好地传播数据的分片。

在这里,您沿着 x 的前导轴求和,并可视化结果值如何存储在多个设备上(使用 jax.debug.visualize_array_sharding()

@jax.jit
def f_contract(x):
  return x.sum(axis=0)

result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
 TPU 0,6  TPU 1,7  TPU 2,4  TPU 3,5 
                                    
[48. 52. 56. 60. 64. 68. 72. 76.]

结果是部分复制的:也就是说,数组的前两个元素在设备 06 上复制,第二个元素在 17 上复制,依此类推。

2. 使用约束的半自动分片#

如果您想对特定计算中使用的分片进行一些控制,JAX 提供了 with_sharding_constraint() 函数。您可以将 jax.lax.with_sharding_constraint() (代替 jax.device_put()) 与 jax.jit() 一起使用,以更好地控制编译器如何约束中间值和输出的分布。

例如,假设在上面的 f_contract 中,您希望输出不是部分复制的,而是完全在八个设备上分片的

@jax.jit
def f_contract_2(x):
  out = x.sum(axis=0)
  sharding = jax.sharding.NamedSharding(mesh, P('x'))
  return jax.lax.with_sharding_constraint(out, sharding)

result = f_contract_2(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
  TPU 0    TPU 1    TPU 2    TPU 3    TPU 6    TPU 7    TPU 4    TPU 5  
                                                                        
[48. 52. 56. 60. 64. 68. 72. 76.]

这为您提供了一个具有您想要的特定输出分片的函数。

3. 使用 shard_map 的手动并行#

在上面探索的自动并行方法中,您可以像操作完整数据集一样编写函数,而 jit 会将该计算拆分到多个设备上。相比之下,使用 jax.experimental.shard_map.shard_map(),您可以编写处理单个数据分片的函数,而 shard_map 将构造完整的函数。

shard_map 通过在特定的设备网格上映射函数来工作(shard_map 在分片上映射)。在下面的示例中

  • 和以前一样,jax.sharding.Mesh 允许精确的设备放置,并为逻辑轴和物理轴名称提供轴名称参数。

  • in_specs 参数确定分片大小。out_specs 参数标识如何将块重新组合在一起。

注意: 如果需要,jax.experimental.shard_map.shard_map() 代码可以在 jax.jit() 内部工作。

from jax.experimental.shard_map import shard_map
mesh = jax.make_mesh((8,), ('x',))

f_elementwise_sharded = shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))

arr = jnp.arange(32)
f_elementwise_sharded(arr)
Array([ 1.        ,  2.682942  ,  2.818595  ,  1.28224   , -0.513605  ,
       -0.9178486 ,  0.44116896,  2.3139732 ,  2.9787164 ,  1.824237  ,
       -0.08804226, -0.99998045, -0.07314599,  1.8403342 ,  2.9812148 ,
        2.3005757 ,  0.42419332, -0.92279506, -0.50197446,  1.2997544 ,
        2.8258905 ,  2.6733112 ,  0.98229736, -0.69244075, -0.81115675,
        0.7352965 ,  2.525117  ,  2.912752  ,  1.5418116 , -0.32726777,
       -0.97606325,  0.19192469], dtype=float32)

您编写的函数只“看到”单个数据批次,您可以通过打印设备本地形状来检查

x = jnp.arange(32)
print(f"global shape: {x.shape=}")

def f(x):
  print(f"device local shape: {x.shape=}")
  return x * 2

y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
global shape: x.shape=(32,)
device local shape: x.shape=(4,)

由于您的每个函数只“看到”数据的设备本地部分,这意味着类似聚合的函数需要一些额外的考虑。

例如,下面是 jax.numpy.sum()shard_map 的样子

def f(x):
  return jnp.sum(x, keepdims=True)

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
Array([  6,  22,  38,  54,  70,  86, 102, 118], dtype=int32)

您的函数 f 单独在每个分片上操作,并且产生的求和反映了这一点。

如果要跨分片求和,则需要使用像 jax.lax.psum() 这样的集体操作显式请求它

def f(x):
  sum_in_shard = x.sum()
  return jax.lax.psum(sum_in_shard, 'x')

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
Array(496, dtype=int32)

由于输出不再具有分片维度,请设置 out_specs=P() (回想一下 out_specs 参数标识如何在 shard_map 中将块重新组合在一起)。

比较这三种方法#

考虑到这些概念,让我们比较一下简单神经网络层的三种方法。

首先定义您的规范函数,如下所示

@jax.jit
def layer(x, weights, bias):
  return jax.nn.sigmoid(x @ weights + bias)
import numpy as np
rng = np.random.default_rng(0)

x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))

layer(x, weights, bias)
Array([0.02138912, 0.893112  , 0.59892005, 0.97742504], dtype=float32)

您可以使用 jax.jit() 并传递适当分片的数据,以分布式方式自动运行此操作。

如果您以相同的方式对 xweights 的前导轴进行分片,则矩阵乘法将自动并行发生

mesh = jax.make_mesh((8,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, P('x'))

x_sharded = jax.device_put(x, sharding)
weights_sharded = jax.device_put(weights, sharding)

layer(x_sharded, weights_sharded, bias)
Array([0.02138912, 0.893112  , 0.59892005, 0.97742504], dtype=float32)

或者,您可以在函数中使用 jax.lax.with_sharding_constraint() 来自动分发未分片的输入

@jax.jit
def layer_auto(x, weights, bias):
  x = jax.lax.with_sharding_constraint(x, sharding)
  weights = jax.lax.with_sharding_constraint(weights, sharding)
  return layer(x, weights, bias)

layer_auto(x, weights, bias)  # pass in unsharded inputs
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)

最后,您可以使用 shard_map 执行相同的操作,使用 jax.lax.psum() 来指示矩阵乘法所需的跨分片集体操作

from functools import partial

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('x'), P('x', None), P(None)),
         out_specs=P(None))
def layer_sharded(x, weights, bias):
  return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)

layer_sharded(x, weights, bias)
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)

下一步#

本教程简要介绍了 JAX 中的分片和并行计算。

要深入了解每种 SPMD 方法,请查看以下文档