分布式数组和自动并行化#

Open in Colab Open in Kaggle

本教程讨论了通过 jax.Array 进行并行化,这是 JAX v0.4.1 及更高版本中可用的统一数组对象模型。

from typing import Optional

import numpy as np

import jax
import jax.numpy as jnp

⚠️ 警告:此笔记本需要 8 个设备才能运行。

if len(jax.local_devices()) < 8:
  raise Exception("Notebook requires 8 devices to run")

引言和一个快速示例#

通过阅读本教程笔记本,您将了解 jax.Array,这是一种用于表示数组的统一数据类型,即使物理存储跨越多个设备。您还将了解如何使用 jax.Arrayjax.jit 结合使用,从而实现基于编译器的自动并行化。

在我们逐步思考之前,这里有一个快速示例。首先,我们将创建一个跨多个设备分片的 jax.Array

from jax.sharding import PartitionSpec as P, NamedSharding
# Create a Sharding object to distribute a value across devices:
mesh = jax.make_mesh((4, 2), ('x', 'y'))
# Create an array of random values:
x = jax.random.normal(jax.random.key(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:
y = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

接下来,我们将对其应用计算,并可视化结果值如何在多个设备上存储

z = jnp.sin(y)
jax.debug.visualize_array_sharding(z)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

jnp.sin 应用的评估已在存储输入值(和输出值)的设备上自动并行化

# `x` is present on a single device
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready()
The slowest run took 8.96 times longer than the fastest. This could mean that an intermediate result is being cached.
25.2 ms ± 30.9 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
# `y` is sharded across 8 devices.
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()
2.4 ms ± 61.4 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)

现在让我们更详细地了解每个部分!

分片 描述了数组值在设备之间内存中的布局方式#

分片基础知识和 NamedSharding 子类#

为了在多个设备上并行化计算,我们首先必须将输入数据分布在多个设备上。

在 JAX 中,Sharding 对象描述了分布式内存布局。它们可以与 jax.device_put 一起使用,以生成具有分布式布局的值。

例如,这里有一个具有单设备 Sharding 的值

import jax
x = jax.random.normal(jax.random.key(0), (8192, 8192))
jax.debug.visualize_array_sharding(x)
┌───────────────────────┐
│                       │
│                       │
│                       │
│                       │
│         TPU 0         │
│                       │
│                       │
│                       │
│                       │
└───────────────────────┘

在这里,我们使用 jax.debug.visualize_array_sharding 函数来显示值 x 在内存中的存储位置。 x 的所有内容都存储在单个设备上,因此可视化非常枯燥!

但是我们可以通过使用 jax.device_putSharding 对象在多个设备上对 x 进行分片。首先,我们使用 jax.make_mesh 创建一个 Devicesnumpy.ndarray,该函数考虑了硬件拓扑以确定 Device 的顺序

from jax.sharding import Mesh, PartitionSpec, NamedSharding

P = PartitionSpec

mesh = jax.make_mesh((4, 2), ('a', 'b'))
y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

我们可以定义一个辅助函数来简化操作

default_mesh = jax.make_mesh((4, 2), ('a', 'b'))

def mesh_sharding(
    pspec: PartitionSpec, mesh: Optional[Mesh] = None,
  ) -> NamedSharding:
  if mesh is None:
    mesh = default_mesh
  return NamedSharding(mesh, pspec)
y = jax.device_put(x, mesh_sharding(P('a', 'b')))
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

在这里,我们使用 P('a', 'b') 来表示 x 的第一轴和第二轴应分别在设备网格轴 'a''b' 上进行分片。我们可以轻松切换到 P('b', 'a'),以在不同的设备上对 x 的轴进行分片

y = jax.device_put(x, mesh_sharding(P('b', 'a')))
jax.debug.visualize_array_sharding(y)
┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │
│       │       │       │       │
│       │       │       │       │
├───────┼───────┼───────┼───────┤
│       │       │       │       │
│ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘
# This `None` means that `x` is not sharded on its second dimension,
# and since the Mesh axis name 'b' is not mentioned, shards are
# replicated across it.
y = jax.device_put(x, mesh_sharding(P('a', None)))
jax.debug.visualize_array_sharding(y)
┌───────────────────────┐
│        TPU 0,1        │
├───────────────────────┤
│        TPU 2,3        │
├───────────────────────┤
│        TPU 6,7        │
├───────────────────────┤
│        TPU 4,5        │
└───────────────────────┘

在这里,由于 P('a', None) 没有提及 Mesh 轴名称 'b',因此我们在轴 'b' 上获得了复制。这里的 None 只是充当占位符,以与值 x 的第二轴对齐,而没有表示在任何网格轴上进行分片。(作为简写,可以省略尾随的 None,因此 P('a', None)P('a') 的含义相同。但明确指出没有坏处!)

要仅在 x 的第二轴上进行分片,我们可以在 PartitionSpec 中使用 None 占位符

y = jax.device_put(x, mesh_sharding(P(None, 'b')))
jax.debug.visualize_array_sharding(y)
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘
y = jax.device_put(x, mesh_sharding(P(None, 'a')))
jax.debug.visualize_array_sharding(y)
┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│TPU 0,1│TPU 2,3│TPU 6,7│TPU 4,5│
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘

对于固定的网格,我们甚至可以在多个设备网格轴上划分 x 的一个逻辑轴

y = jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))
jax.debug.visualize_array_sharding(y)
┌───────────────────────┐
│         TPU 0         │
├───────────────────────┤
│         TPU 1         │
├───────────────────────┤
│         TPU 2         │
├───────────────────────┤
│         TPU 3         │
├───────────────────────┤
│         TPU 6         │
├───────────────────────┤
│         TPU 7         │
├───────────────────────┤
│         TPU 4         │
├───────────────────────┤
│         TPU 5         │
└───────────────────────┘

使用 NamedSharding 可以很容易地定义一次设备网格并为其轴命名,然后只需在每个 device_putPartitionSpec 中引用这些名称即可。

计算遵循数据分片并自动并行化#

通过分片的输入数据,编译器可以为我们提供并行计算。特别是,用 jax.jit 修饰的函数可以在分片的数组上操作,而无需将数据复制到单个设备上。相反,计算遵循分片:根据输入数据的分片,编译器决定中间值和输出值的分片,并并行化它们的评估,甚至在必要时插入通信操作。

例如,最简单的计算是逐元素的计算

mesh = jax.make_mesh((4, 2), ('a', 'b'))
x = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
print('input sharding:')
jax.debug.visualize_array_sharding(x)

y = jnp.sin(x)
print('output sharding:')
jax.debug.visualize_array_sharding(y)
input sharding:
output sharding:
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

这里,对于逐元素操作 jnp.sin,编译器选择输出分片与输入相同。此外,编译器自动并行化了计算,以便每个设备并行地从其输入分片计算其输出分片。

换句话说,即使我们编写 jnp.sin 计算时假设单个机器执行它,编译器也会为我们分割计算并在多个设备上执行它。

我们也可以对不仅仅是逐元素的操作执行相同的操作。考虑一个具有分片输入的矩阵乘法

y = jax.device_put(x, NamedSharding(mesh, P('a', None)))
z = jax.device_put(x, NamedSharding(mesh, P(None, 'b')))
print('lhs sharding:')
jax.debug.visualize_array_sharding(y)
print('rhs sharding:')
jax.debug.visualize_array_sharding(z)

w = jnp.dot(y, z)
print('out sharding:')
jax.debug.visualize_array_sharding(w)
lhs sharding:
rhs sharding:
out sharding:
┌───────────────────────┐
│        TPU 0,1        │
├───────────────────────┤
│        TPU 2,3        │
├───────────────────────┤
│        TPU 6,7        │
├───────────────────────┤
│        TPU 4,5        │
└───────────────────────┘
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

在这里,编译器选择了输出分片,以便可以最大程度地并行化计算:无需通信,每个设备已经具有计算其输出分片所需的输入分片。

我们如何确定它实际上是并行运行的?我们可以做一个简单的计时实验

x_single = jax.device_put(x, jax.devices()[0])
jax.debug.visualize_array_sharding(x_single)
┌───────────────────────┐
│                       │
│                       │
│                       │
│                       │
│         TPU 0         │
│                       │
│                       │
│                       │
│                       │
└───────────────────────┘
np.allclose(jnp.dot(x_single, x_single),
            jnp.dot(y, z))
True
%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready()
49.7 ms ± 349 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)
%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready()
7.47 ms ± 44.8 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)

即使复制分片的 Array 也会产生一个具有输入分片的结果

w_copy = jnp.copy(w)
jax.debug.visualize_array_sharding(w_copy)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

因此,计算遵循数据放置:当我们使用 jax.device_put 显式分片数据,并将函数应用于该数据时,编译器会尝试并行化计算并决定输出分片。这种针对分片数据的策略是 JAX 遵循显式设备放置的策略的概括。

当显式分片不一致时,JAX 会报错#

但是,如果计算的两个参数被显式放置在不同的设备集上,或者具有不兼容的设备顺序怎么办?在这些不明确的情况下,会引发错误

import textwrap
from termcolor import colored

def print_exception(e):
  name = colored(f'{type(e).__name__}', 'red', force_color=True)
  print(textwrap.fill(f'{name}: {str(e)}'))
sharding1 = NamedSharding(Mesh(jax.devices()[:4], 'x'), P('x'))
sharding2 = NamedSharding(Mesh(jax.devices()[4:], 'x'), P('x'))

y = jax.device_put(x, sharding1)
z = jax.device_put(x, sharding2)
try: y + z
except ValueError as e: print_exception(e)
ValueError: Received incompatible devices for jitted
computation. Got argument x1 of jax.numpy.add with shape int32[24] and
device ids [0, 1, 2, 3] on platform TPU and argument x2 of
jax.numpy.add with shape int32[24] and device ids [4, 5, 6, 7] on
platform TPU
devices = jax.devices()
permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]

sharding1 = NamedSharding(Mesh(devices, 'x'), P('x'))
sharding2 = NamedSharding(Mesh(permuted_devices, 'x'), P('x'))

y = jax.device_put(x, sharding1)
z = jax.device_put(x, sharding2)
try: y + z
except ValueError as e: print_exception(e)
ValueError: Received incompatible devices for jitted
computation. Got argument x1 of jax.numpy.add with shape int32[24] and
device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform TPU and argument x2 of
jax.numpy.add with shape int32[24] and device ids [0, 1, 2, 3, 6, 7,
4, 5] on platform TPU

我们说已使用 jax.device_put 显式放置或分片的数组将 提交 到其设备,因此不会自动移动。有关更多信息,请参阅设备放置 FAQ

当数组 使用 jax.device_put 显式放置或分片时,它们会 未提交 地放置在默认设备上。与已提交的数组不同,未提交的数组可以自动移动和重新分片:也就是说,即使其他参数显式放置在不同的设备上,未提交的数组也可以作为计算的参数。

例如,jnp.zerosjnp.arangejnp.array 的输出是未提交的

y = jax.device_put(x, sharding1)
y + jnp.ones_like(y)
y + jnp.arange(y.size).reshape(y.shape)
print('no error!')
no error!

约束 jit 代码中中间体的分片#

虽然编译器会尝试确定函数的中间值和输出应如何分片,但我们也可以使用 jax.lax.with_sharding_constraint 向其提供提示。使用 jax.lax.with_sharding_constraintjax.device_put 非常相似,只不过我们在暂存(即用 jit 修饰)的函数内部使用它

mesh = jax.make_mesh((4, 2), ('x', 'y'))
x = jax.random.normal(jax.random.key(0), (8192, 8192))
x = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
@jax.jit
def f(x):
  x = x + 1
  y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('y', 'x')))
  return y
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │
│       │       │       │       │
│       │       │       │       │
├───────┼───────┼───────┼───────┤
│       │       │       │       │
│ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘
@jax.jit
def f(x):
  x = x + 1
  y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
  return y
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
┌───────────────────────┐
│                       │
│                       │
│                       │
│                       │
│  TPU 0,1,2,3,4,5,6,7  │
│                       │
│                       │
│                       │
│                       │
└───────────────────────┘

通过添加 with_sharding_constraint,我们约束了输出的分片。除了尊重特定中间体的注释外,编译器还将使用注释来确定其他值的分片。

通常,根据值的最终使用方式来注释计算的输出是一个好习惯。

示例:神经网络#

⚠️ 警告:以下内容旨在简单演示使用 jax.Array 的自动分片传播,但它可能不反映实际示例的最佳实践。 例如,实际示例可能需要更多使用 with_sharding_constraint

我们可以使用 jax.device_putjax.jit 的计算遵循分片功能来并行化神经网络中的计算。以下是一些基于此基本神经网络的简单示例

import jax
import jax.numpy as jnp
def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.maximum(outputs, 0)
  return outputs

def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
loss_jit = jax.jit(loss)
gradfun = jax.jit(jax.grad(loss))
def init_layer(key, n_in, n_out):
  k1, k2 = jax.random.split(key)
  W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
  b = jax.random.normal(k2, (n_out,))
  return W, b

def init_model(key, layer_sizes, batch_size):
  key, *keys = jax.random.split(key, len(layer_sizes))
  params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))

  key, *keys = jax.random.split(key, 3)
  inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
  targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))

  return params, (inputs, targets)

layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192

params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)

8 路批数据并行#

mesh = jax.make_mesh((8,), ('batch',))
sharding = NamedSharding(mesh, P('batch'))
replicated_sharding = NamedSharding(mesh, P())
batch = jax.device_put(batch, sharding)
params = jax.device_put(params, replicated_sharding)
loss_jit(params, batch)
Array(23.469475, dtype=float32)
step_size = 1e-5

for _ in range(30):
  grads = gradfun(params, batch)
  params = [(W - step_size * dW, b - step_size * db)
            for (W, b), (dW, db) in zip(params, grads)]

print(loss_jit(params, batch))
10.760109
%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()
53.8 ms ± 1.14 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
batch_single = jax.device_put(batch, jax.devices()[0])
params_single = jax.device_put(params, jax.devices()[0])
%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()
351 ms ± 81.2 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)

4 路批数据并行和 2 路模型张量并行#

mesh = jax.make_mesh((4, 2), ('batch', 'model'))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch', None)))
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1])
┌───────┐
│TPU 0,1│
├───────┤
│TPU 2,3│
├───────┤
│TPU 6,7│
├───────┤
│TPU 4,5│
└───────┘
┌───────┐
│TPU 0,1│
├───────┤
│TPU 2,3│
├───────┤
│TPU 6,7│
├───────┤
│TPU 4,5│
└───────┘
replicated_sharding = NamedSharding(mesh, P())
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params

W1 = jax.device_put(W1, replicated_sharding)
b1 = jax.device_put(b1, replicated_sharding)

W2 = jax.device_put(W2, NamedSharding(mesh, P(None, 'model')))
b2 = jax.device_put(b2, NamedSharding(mesh, P('model')))

W3 = jax.device_put(W3, NamedSharding(mesh, P('model', None)))
b3 = jax.device_put(b3, replicated_sharding)

W4 = jax.device_put(W4, replicated_sharding)
b4 = jax.device_put(b4, replicated_sharding)

params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)
jax.debug.visualize_array_sharding(W2)
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘
jax.debug.visualize_array_sharding(W3)
┌───────────────────────┐
│                       │
│      TPU 0,2,4,6      │
│                       │
│                       │
├───────────────────────┤
│                       │
│      TPU 1,3,5,7      │
│                       │
│                       │
└───────────────────────┘
print(loss_jit(params, batch))
10.760109
step_size = 1e-5

for _ in range(30):
    grads = gradfun(params, batch)
    params = [(W - step_size * dW, b - step_size * db)
              for (W, b), (dW, db) in zip(params, grads)]
print(loss_jit(params, batch))
10.752513
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
jax.debug.visualize_array_sharding(W2)
jax.debug.visualize_array_sharding(W3)
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘
┌───────────────────────┐
│                       │
│      TPU 0,2,4,6      │
│                       │
│                       │
├───────────────────────┤
│                       │
│      TPU 1,3,5,7      │
│                       │
│                       │
└───────────────────────┘
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()
51.4 ms ± 454 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)

注意事项#

生成随机数#

JAX 附带一个功能性、确定性的 随机数生成器。它是 jax.random 模块中各种采样函数的基础,例如 jax.random.uniform

JAX 的随机数由基于计数器的 PRNG 生成,因此原则上,随机数生成应该是计数器值的纯映射。原则上,纯映射是一个可简单分区的操作。它应该不需要跨设备通信,也不需要跨设备进行任何冗余计算。

然而,由于历史原因,现有的稳定 RNG 实现并非自动可分区的。

考虑以下示例,其中一个函数抽取随机均匀数并将它们逐元素添加到输入中

@jax.jit
def f(key, x):
  numbers = jax.random.uniform(key, x.shape)
  return x + numbers

key = jax.random.key(42)
mesh = Mesh(jax.devices(), 'x')
x_sharding = NamedSharding(mesh, P('x'))
x = jax.device_put(jnp.arange(24), x_sharding)

在分区输入上,函数 f 生成的输出也是分区的

jax.debug.visualize_array_sharding(f(key, x))
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘

但是,如果我们检查在分区输入上 f 的已编译计算,我们会发现它确实涉及一些通信

f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())
Communicating? True

一种解决此问题的方法是使用实验性升级标志 jax_threefry_partitionable 配置 JAX。启用该标志后,已编译的计算中现在不再有“集体置换”操作

jax.config.update('jax_threefry_partitionable', True)
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())
Communicating? False

输出仍然是分区的

jax.debug.visualize_array_sharding(f(key, x))
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘

然而,jax_threefry_partitionable 选项的一个注意事项是,即使它们是由同一个随机密钥生成的,所产生的随机值也可能与未设置该标志时不同

jax.config.update('jax_threefry_partitionable', False)
print('Stable:')
print(f(key, x))
print()

jax.config.update('jax_threefry_partitionable', True)
print('Partitionable:')
print(f(key, x))
Stable:
[ 0.72503686  1.8532515   2.983416    3.083253    4.0332246   5.4782867
  6.1720605   7.6900277   8.602836    9.810046   10.861367   11.907651
 12.330483   13.456195   14.808557   15.960099   16.067581   17.739723
 18.335474   19.46401    20.390276   21.116539   22.858128   23.223194  ]

Partitionable:
[ 0.48870957  1.6797972   2.6162715   3.561016    4.4506445   5.585866
  6.0748096   7.775133    8.698959    9.818634   10.350306   11.87282
 12.925881   13.86013    14.477554   15.818481   16.711355   17.586697
 18.073738   19.777622   20.404566   21.119123   22.026257   23.63918   ]

jax_threefry_partitionable 模式下,JAX PRNG 仍然是确定性的,但其实现是新的(并且正在开发中)。对于给定的密钥,生成的随机值在给定的 JAX 版本(或 main 分支上的给定提交)中将是相同的,但在不同版本之间可能会有所不同。