jax.Array 迁移#

yashkatariya@

TL;DR#

从 0.4.1 版本开始,JAX 将其默认数组实现切换为新的 jax.Array。 本指南解释了此背后的原因、它可能对您的代码产生的影响,以及如何(暂时)切换回旧行为。

发生了什么?#

jax.Array 是一种统一的数组类型,它取代了 JAX 中的 DeviceArrayShardedDeviceArrayGlobalDeviceArray 类型。jax.Array 类型有助于使并行性成为 JAX 的核心特性,简化和统一 JAX 内部结构,并允许我们统一 jit 和 pjit。如果你的代码没有提及 DeviceArrayShardedDeviceArrayGlobalDeviceArray,则无需进行任何更改。但是,依赖于这些单独类别的细节的代码可能需要进行调整才能与统一的 jax.Array 一起使用。

迁移完成后,jax.Array 将成为 JAX 中唯一的数组类型。

本文档解释了如何将现有代码库迁移到 jax.Array。有关使用 jax.Array 和 JAX 并行 API 的更多信息,请参阅 分布式数组和自动并行化 教程。

如何启用 jax.Array?#

你可以通过以下方式启用 jax.Array

  • 将 shell 环境变量 JAX_ARRAY 设置为类似 true 的值(例如,1);

  • 如果你的代码使用 absl 解析标志,则将布尔标志 jax_array 设置为类似 true 的值;

  • 在你的主文件的顶部使用此语句

    import jax
    jax.config.update('jax_array', True)
    

如何知道 jax.Array 是否破坏了我的代码?#

判断任何问题是否由 jax.Array 引起的,最简单的方法是禁用 jax.Array,然后查看问题是否消失。

现在如何禁用 jax.Array?#

在 **2023 年 3 月 15 日** 之前,可以通过以下方式禁用 jax.Array:

  • 将 shell 环境变量 JAX_ARRAY 设置为类似 false 的值(例如,0);

  • 如果你的代码使用 absl 解析标志,则将布尔标志 jax_array 设置为类似 false 的值;

  • 在你的主文件的顶部使用此语句

    import jax
    jax.config.update('jax_array', False)
    

为什么要创建 jax.Array?#

目前,JAX 有三种类型:DeviceArrayShardedDeviceArrayGlobalDeviceArrayjax.Array 合并了这三种类型,并在添加新的并行功能的同时清理了 JAX 的内部结构。

我们还引入了一个新的 Sharding 抽象,它描述了逻辑数组如何在诸如 TPU 或 GPU 等一个或多个设备上进行物理分片。此更改还升级、简化和合并了 pjit 的并行功能到 jit 中。使用 jit 修饰的函数将能够对分片数组进行操作,而无需将数据复制到单个设备上。

使用 jax.Array 获得的功能

  • C++ pjit 调度路径

  • 逐操作并行(即使数组分布在多个主机上的多个设备之间)

  • 使用 pjit/jit 实现更简单的批处理数据并行。

  • 创建不一定由网格和分区规范组成的 Sharding 的方法。如果需要,可以充分利用 OpSharding 的灵活性,或任何其他你想要的 Sharding。

  • 还有更多

示例

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np
x = jnp.arange(8)

# Let's say there are 8 devices in jax.devices()
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x'))

sharded_x = jax.device_put(x, sharding)

# `matmul_sharded_x` and `sin_sharded_x` are sharded. `jit` is able to operate over a
# sharded array without copying data to a single device.
matmul_sharded_x = sharded_x @ sharded_x.T
sin_sharded_x = jnp.sin(sharded_x)

# Even jnp.copy preserves the sharding on the output.
copy_sharded_x = jnp.copy(sharded_x)

# double_out is also sharded
double_out = jax.jit(lambda x: x * 2)(sharded_x)

启用 jax.Array 时可能会出现哪些问题?#

名为 jax.Array 的新公共类型#

所有 isinstance(..., jnp.DeviceArray)isinstance(.., jax.xla.DeviceArray) 以及 DeviceArray 的其他变体都应切换为使用 isinstance(..., jax.Array)

由于 jax.Array 可以表示 DA、SDA 和 GDA,你可以通过以下方式在 jax.Array 中区分这 3 种类型:

  • x.is_fully_addressable and len(x.sharding.device_set) == 1 – 这意味着 jax.Array 类似于 DA

  • x.is_fully_addressable and (len(x.sharding.device_set) > 1 – 这意味着 jax.Array 类似于 SDA

  • not x.is_fully_addressable – 这意味着 jax.Array 类似于 GDA,并且跨多个进程

对于 ShardedDeviceArray,你可以将 isinstance(..., pxla.ShardedDeviceArray) 移动到 isinstance(..., jax.Array) and x.is_fully_addressable and len(x.sharding.device_set) > 1

一般来说,不可能将 1 个设备上的 ShardedDeviceArray 与任何其他类型的单设备数组区分开来。

GDA 的 API 名称更改#

GDA 的 local_shardslocal_data 已被弃用。

请使用与 jax.ArrayGDA 兼容的 addressable_shardsaddressable_data

创建 jax.Array#

jax_array 标志为 True 时,所有 JAX 函数都将输出 jax.Array。如果你正在使用 GlobalDeviceArray.from_callbackmake_sharded_device_arraymake_device_array 函数来显式创建相应的 JAX 数据类型,你需要将它们切换为使用 jax.make_array_from_callback()jax.make_array_from_single_device_arrays()

对于 GDA

GlobalDeviceArray.from_callback(shape, mesh, pspec, callback) 可以 1:1 切换为 jax.make_array_from_callback(shape, jax.sharding.NamedSharding(mesh, pspec), callback)

如果你正在使用原始 GDA 构造函数来创建 GDA,则执行此操作

GlobalDeviceArray(shape, mesh, pspec, buffers) 可以变为 jax.make_array_from_single_device_arrays(shape, jax.sharding.NamedSharding(mesh, pspec), buffers)

对于 SDA

make_sharded_device_array(aval, sharding_spec, device_buffers, indices) 可以变为 jax.make_array_from_single_device_arrays(shape, sharding, device_buffers)

为了决定 sharding 应该是什么,这取决于你创建 SDA 的原因

如果它是为了作为 pmap 的输入而创建的,那么 sharding 可以是:jax.sharding.PmapSharding(devices, sharding_spec)

如果它是为了作为 pjit 的输入而创建的,那么 sharding 可以是 jax.sharding.NamedSharding(mesh, pspec)

在为主机本地输入切换到 jax.Array 后,pjit 的重大更改#

如果你完全使用 GDA 参数作为 pjit 的输入,你可以跳过此部分!🎉

启用 jax.Array 后,pjit 的所有输入都必须是全局形状的。 这与之前的行为有所不同,之前的行为是 pjit 会将进程本地参数连接成一个全局值; 现在不再进行这种连接。

为什么要进行此重大更改?现在每个数组都明确说明了其本地分片如何适应全局整体,而不是隐式地表示。 这种更明确的表示还解锁了额外的灵活性,例如,使用非连续网格与 pjit,这可以提高某些 TPU 模型上的效率。

当启用 jax.Array 时,运行多进程 pjit 计算并传递主机本地输入可能会导致类似于以下的错误

示例

网格 = {'x': 2, 'y': 2, 'z': 2},主机本地输入形状 == (4,),pspec = P(('x', 'y', 'z'))

由于 pjit 不会使用 jax.Array 将主机本地形状提升为全局形状,因此您会收到以下错误

注意:只有当您的主机本地形状小于网格的形状时,才会看到此错误。

ValueError: One of pjit arguments was given the sharding of
NamedSharding(mesh={'x': 2, 'y': 2, 'chips': 2}, partition_spec=PartitionSpec(('x', 'y', 'chips'),)),
which implies that the global size of its dimension 0 should be divisible by 8,
but it is equal to 4

这个错误是有道理的,因为当维度 0 上的值为 4 时,您不能将维度 0 分片 8 次。

如果您仍然将主机本地输入传递给 pjit,该如何迁移?我们提供了过渡 API 来帮助您迁移

注意:如果您在单个进程上运行 pjitted 计算,则不需要这些实用程序。

from jax.experimental import multihost_utils

global_inps = multihost_utils.host_local_array_to_global_array(
    local_inputs, mesh, in_pspecs)

global_outputs = pjit(f, in_shardings=in_pspecs,
                      out_shardings=out_pspecs)(global_inps)

local_outs = multihost_utils.global_array_to_host_local_array(
    global_outputs, mesh, out_pspecs)

host_local_array_to_global_array 是一种类型转换,它查看仅具有本地分片的值,并将其本地形状更改为 pjit 在更改之前传递该值时会假设的形状。

仍然支持传入完全复制的输入,即每个进程上的形状相同,且 in_axis_resourcesP(None)。 在这种情况下,您不必使用 host_local_array_to_global_array,因为形状已经是全局的。

key = jax.random.PRNGKey(1)

# As you can see, using host_local_array_to_global_array is not required since in_axis_resources says
# that the input is fully replicated via P(None)
pjit(f, in_shardings=None, out_shardings=None)(key)

# Mixing inputs
global_inp = multihost_utils.host_local_array_to_global_array(
    local_inp, mesh, P('data'))
global_out = pjit(f, in_shardings=(P(None), P('data')),
                  out_shardings=...)(key, global_inp)

FROM_GDA 和 jax.Array#

如果您在 pjitin_axis_resources 参数中使用 FROM_GDA,那么使用 jax.Array 时,无需向 in_axis_resources 传递任何内容,因为 jax.Array 将遵循计算遵循分片的语义。

例如

pjit(f, in_shardings=FROM_GDA, out_shardings=...) can be replaced by pjit(f, out_shardings=...)

如果您的输入(如 numpy 数组等)中混合使用了 PartitionSpecs 和 FROM_GDA,则可以使用 host_local_array_to_global_array 将它们转换为 jax.Array

例如

如果您之前有以下代码

pjitted_f = pjit(
    f, in_shardings=(FROM_GDA, P('x'), FROM_GDA, P(None)),
    out_shardings=...)
pjitted_f(gda1, np_array1, gda2, np_array2)

那么您可以将其替换为


pjitted_f = pjit(f, out_shardings=...)

array2, array3 = multihost_utils.host_local_array_to_global_array(
    (np_array1, np_array2), mesh, (P('x'), P(None)))

pjitted_f(array1, array2, array3, array4)

live_buffers 替换为 live_arrays#

jax Device 上的 live_buffers 属性已弃用。请改用与 jax.Array 兼容的 jax.live_arrays()

处理 pjit 的主机本地输入(如批处理等)#

如果您在多进程环境中将主机本地输入传递给 pjit,则请使用 multihost_utils.host_local_array_to_global_array 将批处理转换为全局 jax.Array,然后将其传递给 pjit

这种主机本地输入最常见的示例是一批输入数据

这将适用于任何主机本地输入(不仅仅是一批输入数据)。

from jax.experimental import multihost_utils

batch = multihost_utils.host_local_array_to_global_array(
    batch, mesh, batch_partition_spec)

有关此更改和更多示例的详细信息,请参阅上面的 pjit 部分。

RecursionError:递归调用 jit#

当您的代码的某些部分禁用了 jax.Array,然后您仅在其他某些部分启用它时,会发生这种情况。 例如,如果您使用一些禁用了 jax.Array 的第三方代码,并且从该库中获取了一个 DeviceArray,然后在您的库中启用 jax.Array 并将该 DeviceArray 传递给 JAX 函数,则会导致 RecursionError。

当默认启用 jax.Array 时,此错误应该会消失,以便所有库都返回 jax.Array,除非它们明确禁用它。