jax.pmap

内容

jax.pmap#

jax.pmap(fun, axis_name=None, *, in_axes=0, out_axes=0, static_broadcasted_argnums=(), devices=None, backend=None, axis_size=None, donate_argnums=(), global_arg_shapes=None)[source]#

支持集体操作的并行映射。

函数 pmap() 的目的是表达单程序多数据 (SPMD) 程序。 将 pmap() 应用于函数将使用 XLA 编译该函数(类似于 jit()),然后在 XLA 设备(例如多个 GPU 或多个 TPU 内核)上并行执行它。 从语义上讲,它类似于 vmap(),因为这两个转换都将函数映射到数组轴上,但 vmap() 通过将映射的轴推入基本操作来向量化函数,而 pmap() 则复制函数并在自己的 XLA 设备上并行执行每个副本。

映射的轴大小必须小于或等于可用的本地 XLA 设备数量,如 jax.local_device_count() 返回(除非指定了 devices,见下文)。 对于嵌套的 pmap() 调用,映射的轴大小的乘积必须小于或等于 XLA 设备的数量。

注意

pmap() 会编译 fun,因此虽然它可以与 jit() 组合使用,但这通常是不必要的。

pmap() 要求所有参与的设备都相同。 例如,无法使用 pmap() 在两个不同型号的 GPU 上并行化计算。 目前,同一个设备在同一个 pmap 中两次参与是一个错误。

多进程平台:在多进程平台(如 TPU 集群)上,pmap() 旨在用于 SPMD Python 程序,其中每个进程都运行相同的 Python 代码,因此所有进程按相同顺序运行相同的 pmapped 函数。 每个进程仍然应该以映射的轴大小等于本地设备的数量(除非指定了 devices,见下文)来调用 pmapped 函数,并且将像往常一样返回具有相同前导轴大小的数组。 但是,fun 中的任何集体操作都将通过设备到设备通信在所有参与的设备上执行,包括其他进程上的设备。 从概念上讲,这可以被认为是在跨进程分片的单个数组上运行 pmap,其中每个进程仅“看到”其本地分片输入和输出。 SPMD 模型要求在所有设备上以相同顺序运行相同的多进程 pmaps,但它们可以与在单个进程中运行的任意操作交织在一起。

参数:
  • fun (Callable) – 要映射到参数轴上的函数。 它的参数和返回值应该是数组、标量或(嵌套)标准 Python 容器(元组/列表/字典)。 由 static_broadcasted_argnums 指示的位置参数可以是任何东西,只要它们是可哈希的并且定义了相等操作。

  • axis_name (AxisName | None | None) – 可选,一个可哈希的 Python 对象,用于标识映射的轴,以便可以应用并行集合操作。

  • in_axes – 一个非负整数、None 或其嵌套的 Python 容器,用于指定要映射的位置参数的哪些轴。 作为关键字传递的参数始终映射到它们的领先轴(即轴索引 0)。 有关详细信息,请参见 vmap()

  • out_axes – 一个非负整数、None 或其嵌套的 Python 容器,指示映射的轴应该出现在输出中的位置。 所有具有映射轴的输出都必须具有非 None out_axes 规范(参见 vmap())。

  • static_broadcasted_argnums (int | Iterable[int]) –

    一个 int 或 int 集合,指定哪些位置参数视为静态(编译时常量)。 仅依赖于静态参数的操作将被常量折叠。 使用这些常量的不同值调用 pmapped 函数将触发重新编译。 如果 pmapped 函数调用的位置参数少于 static_broadcasted_argnums 指示的,则会引发错误。 每个静态参数都将被广播到所有设备。 不是数组或其容器的参数必须标记为静态。 默认为 ()。

    静态参数必须是可哈希的,这意味着 __hash____eq__ 都已实现,并且应该是不可变的。

  • devices (Sequence[xc.Device] | None | None) – 这是一个实验性功能,API 可能发生变化。 可选,一个要映射的设备序列。 (可用的设备可以通过 jax.devices() 获取)。 必须在多进程设置中为每个进程以相同的方式给出(因此将包括跨进程的设备)。 如果指定,则映射的轴的大小必须等于给定进程的本地设备序列中的设备数量。 尚未支持在内部或外部 pmap() 中都指定了 devices 的嵌套 pmap()

  • backend (str | None | None) – 这是一个实验性功能,API 可能发生变化。 可选,一个表示 XLA 后端的字符串。 'cpu'、'gpu' 或 'tpu'。

  • axis_size (int | None | None) – 可选;映射的轴的大小。

  • donate_argnums (int | Iterable[int]) –

    指定哪些位置参数缓冲区“捐赠”给计算。 如果您在计算完成之后不再需要参数缓冲区,那么捐赠它们是安全的。 在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如,回收您的一个输入缓冲区来存储结果。 您不应该重用您捐赠给计算的缓冲区,如果您尝试这样做,JAX 将会引发错误。 请注意,donate_argnums 仅适用于位置参数,关键字参数不会被捐赠。

    有关缓冲区捐赠的更多详细信息,请参见 FAQ

  • global_arg_shapes (tuple[tuple[int, ...], ...] | None | None)

返回值:

一个并行化的 fun 版本,其参数对应于 fun 的参数,但在 in_axes 指示的位置具有额外的数组轴,并且输出具有额外的领先数组轴(具有相同的大小)。

返回类型:

任何

例如,假设有 8 个 XLA 设备可用,pmap() 可以用作沿着领先数组轴的映射

>>> import jax.numpy as jnp
>>>
>>> out = pmap(lambda x: x ** 2)(jnp.arange(8))  
>>> print(out)  
[0, 1, 4, 9, 16, 25, 36, 49]

当领先维度小于可用设备数量时,JAX 只会在设备的子集上运行

>>> x = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2))
>>> y = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
>>> out = pmap(jnp.dot)(x, y)  
>>> print(out)  
[[[    4.     9.]
  [   12.    29.]]
 [[  244.   345.]
  [  348.   493.]]
 [[ 1412.  1737.]
  [ 1740.  2141.]]]

如果您的领先维度大于可用设备的数量,您将收到错误

>>> pmap(lambda x: x ** 2)(jnp.arange(9))  
ValueError: ... requires 9 replicas, but only 8 XLA devices are available

vmap() 一样,在 in_axes 中使用 None 表示参数没有额外的轴,应该被广播而不是映射到副本

>>> x, y = jnp.arange(2.), 4.
>>> out = pmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None))(x, y)  
>>> print(out)  
([4., 5.], [8., 8.])

请注意,pmap() 始终返回映射到它们领先轴的值,等效于在 vmap() 中使用 out_axes=0

除了表达纯映射,pmap() 也可用于表达通过集体操作进行通信的并行单程序多数据 (SPMD) 程序。 例如

>>> f = lambda x: x / jax.lax.psum(x, axis_name='i')
>>> out = pmap(f, axis_name='i')(jnp.arange(4.))  
>>> print(out)  
[ 0.          0.16666667  0.33333334  0.5       ]
>>> print(out.sum())  
1.0

在本例中,axis_name 是一个字符串,但它可以是任何定义了 __hash____eq__ 的 Python 对象。

参数 axis_name 传递给 pmap() 用于命名映射的轴,以便像 jax.lax.psum() 这样的集合操作可以引用它。轴名称在嵌套的 pmap() 函数中尤为重要,因为在这些函数中,集合操作可以在不同的轴上进行。

>>> from functools import partial
>>> import jax
>>>
>>> @partial(pmap, axis_name='rows')
... @partial(pmap, axis_name='cols')
... def normalize(x):
...   row_normed = x / jax.lax.psum(x, 'rows')
...   col_normed = x / jax.lax.psum(x, 'cols')
...   doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
...   return row_normed, col_normed, doubly_normed
>>>
>>> x = jnp.arange(8.).reshape((4, 2))
>>> row_normed, col_normed, doubly_normed = normalize(x)  
>>> print(row_normed.sum(0))  
[ 1.  1.]
>>> print(col_normed.sum(1))  
[ 1.  1.  1.  1.]
>>> print(doubly_normed.sum((0, 1)))  
1.0

在多进程平台上,集合操作将在所有设备上进行,包括其他进程上的设备。例如,假设以下代码在两个进程上运行,每个进程有 4 个 XLA 设备。

>>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
>>> data = jnp.arange(4) if jax.process_index() == 0 else jnp.arange(4, 8)
>>> out = pmap(f, axis_name='i')(data)  
>>> print(out)  
[28 29 30 31] # on process 0
[32 33 34 35] # on process 1

每个进程传递一个不同的长度为 4 的数组,对应于其 4 个本地设备,psum 操作在所有 8 个值上进行。从概念上讲,这两个长度为 4 的数组可以被认为是一个长度为 8 的分片数组(在本例中相当于 jnp.arange(8)),它被映射,长度为 8 的映射轴被命名为 'i'。每个进程上的 pmap 调用然后返回相应的长度为 4 的输出分片。

参数 devices 可用于指定运行并行计算所使用的确切设备。例如,同样假设一个进程有 8 个设备,以下代码定义了两个并行计算,一个在头 6 个设备上运行,另一个在剩余的 2 个设备上运行。

>>> from functools import partial
>>> @partial(pmap, axis_name='i', devices=jax.devices()[:6])
... def f1(x):
...   return x / jax.lax.psum(x, axis_name='i')
>>>
>>> @partial(pmap, axis_name='i', devices=jax.devices()[-2:])
... def f2(x):
...   return jax.lax.psum(x ** 2, axis_name='i')
>>>
>>> print(f1(jnp.arange(6.)))  
[0.         0.06666667 0.13333333 0.2        0.26666667 0.33333333]
>>> print(f2(jnp.array([2., 3.])))  
[ 13.  13.]