jax.Array 迁移#
yashkatariya@
太长不看#
从 0.4.1 版本开始,JAX 将其默认数组实现切换到新的 jax.Array
。本指南解释了这样做的原因、它可能对您的代码产生的影响以及如何(暂时)切换回旧行为。
发生了什么?#
jax.Array
是一种统一的数组类型,它包含了 JAX 中的 DeviceArray
、ShardedDeviceArray
和 GlobalDeviceArray
类型。 jax.Array
类型有助于使并行性成为 JAX 的核心特性,简化和统一 JAX 内部结构,并允许我们统一 jit 和 pjit。如果您的代码没有提到 DeviceArray
与 ShardedDeviceArray
与 GlobalDeviceArray
的区别,则无需进行任何更改。但是,依赖于这些单独类别的详细信息的代码可能需要进行调整才能与统一的 jax.Array 协同工作。
迁移完成后,jax.Array
将成为 JAX 中唯一的数组类型。
本文档说明了如何将现有代码库迁移到 jax.Array
。有关使用 jax.Array
和 JAX 并行 API 的更多信息,请参阅 分布式数组和自动并行化 教程。
如何启用 jax.Array?#
您可以通过以下方式启用 jax.Array
将 shell 环境变量
JAX_ARRAY
设置为类似 true 的值(例如,1
);如果您的代码使用 absl 解析标志,则将布尔标志
jax_array
设置为类似 true 的值;在主文件顶部使用此语句
import jax jax.config.update('jax_array', True)
如何知道 jax.Array 是否破坏了我的代码?#
确定 jax.Array
是否导致任何问题的最简单方法是禁用 jax.Array
并查看问题是否消失。
如何暂时禁用 jax.Array?#
在 **2023 年 3 月 15 日** 之前,可以通过以下方式禁用 jax.Array:
将 shell 环境变量
JAX_ARRAY
设置为类似 false 的值(例如,0
);如果您的代码使用 absl 解析标志,则将布尔标志
jax_array
设置为类似 false 的值;在主文件顶部使用此语句
import jax jax.config.update('jax_array', False)
为什么要创建 jax.Array?#
目前 JAX 有三种类型;DeviceArray
、ShardedDeviceArray
和 GlobalDeviceArray
。 jax.Array
合并了这三种类型,并清理了 JAX 的内部结构,同时添加了新的并行功能。
我们还引入了一个新的 Sharding
抽象,它描述了逻辑数组如何在单个或多个设备(例如 TPU 或 GPU)上进行物理分片。此更改还升级、简化和合并了 pjit
的并行功能到 jit
中。使用 jit
装饰的函数将能够在分片数组上运行,而无需将数据复制到单个设备上。
使用 jax.Array
获得的功能
C++
pjit
派发路径逐操作并行(即使数组分布在多个主机上的多个设备上)
使用
pjit
/jit
简化批处理数据并行。创建
Sharding
的方法,这些方法不一定由网格和分区规范组成。如果需要,可以充分利用 OpSharding 的灵活性,或者任何您想要的其他 Sharding。以及更多
示例
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np
x = jnp.arange(8)
# Let's say there are 8 devices in jax.devices()
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
sharded_x = jax.device_put(x, sharding)
# `matmul_sharded_x` and `sin_sharded_x` are sharded. `jit` is able to operate over a
# sharded array without copying data to a single device.
matmul_sharded_x = sharded_x @ sharded_x.T
sin_sharded_x = jnp.sin(sharded_x)
# Even jnp.copy preserves the sharding on the output.
copy_sharded_x = jnp.copy(sharded_x)
# double_out is also sharded
double_out = jax.jit(lambda x: x * 2)(sharded_x)
启用 jax.Array 时可能出现什么问题?#
名为 jax.Array 的新公共类型#
所有 isinstance(..., jnp.DeviceArray)
或 isinstance(.., jax.xla.DeviceArray)
以及 DeviceArray
的其他变体都应切换为使用 isinstance(..., jax.Array)
。
由于 jax.Array
可以表示 DA、SDA 和 GDA,因此可以通过以下方式在 jax.Array
中区分这三种类型:
x.is_fully_addressable and len(x.sharding.device_set) == 1
– 这表示jax.Array
类似于 DAx.is_fully_addressable and (len(x.sharding.device_set) > 1
– 这表示jax.Array
类似于 SDAnot x.is_fully_addressable
– 这表示jax.Array
类似于 GDA 并跨越多个进程
对于 ShardedDeviceArray
,您可以将 isinstance(..., pxla.ShardedDeviceArray)
移动到 isinstance(..., jax.Array) and x.is_fully_addressable and len(x.sharding.device_set) > 1
。
通常无法将 1 个设备上的 ShardedDeviceArray
与任何其他类型的单设备数组区分开来。
GDA 的 API 名称更改#
GDA 的 local_shards
和 local_data
已弃用。
请使用 addressable_shards
和 addressable_data
,它们与 jax.Array
和 GDA
兼容。
创建 jax.Array#
当 jax_array
标志为 True 时,所有 JAX 函数都会输出 jax.Array
。如果您以前使用 GlobalDeviceArray.from_callback
或 make_sharded_device_array
或 make_device_array
函数显式创建相应的 JAX 数据类型,则需要将它们切换为使用 jax.make_array_from_callback()
或 jax.make_array_from_single_device_arrays()
。
对于 GDA
GlobalDeviceArray.from_callback(shape, mesh, pspec, callback)
可以变为 jax.make_array_from_callback(shape, jax.sharding.NamedSharding(mesh, pspec), callback)
,进行 1:1 切换。
如果您使用原始 GDA 构造函数创建 GDA,则执行以下操作
GlobalDeviceArray(shape, mesh, pspec, buffers)
可以变为 jax.make_array_from_single_device_arrays(shape, jax.sharding.NamedSharding(mesh, pspec), buffers)
对于 SDA
make_sharded_device_array(aval, sharding_spec, device_buffers, indices)
可以变为 jax.make_array_from_single_device_arrays(shape, sharding, device_buffers)
。
要确定 sharding 应该是什么,取决于您创建 SDA 的原因
如果它是为了作为 pmap
的输入创建的,则 sharding 可以是:jax.sharding.PmapSharding(devices, sharding_spec)
。
如果它是为了作为 pjit
的输入创建的,则 sharding 可以是 jax.sharding.NamedSharding(mesh, pspec)
。
将主机本地输入切换到 jax.Array 后,pjit 的重大更改#
如果您专门使用 GDA 参数传递给 pjit,则可以跳过此部分!🎉
启用 jax.Array
后,传递给 pjit
的所有输入都必须是全局形状。这与之前的行为存在重大差异,在之前的行为中,pjit
会将进程本地参数连接成全局值;这种连接不再发生。
我们为什么要进行这种重大更改?每个数组现在都明确说明其本地分片如何适应全局整体,而不是将其隐式地保留下来。更明确的表示还提供了额外的灵活性,例如在 pjit
中使用非连续网格,这可以提高某些 TPU 模型的效率。
运行 **多进程 pjit 计算** 并在启用 jax.Array
时传递主机本地输入可能会导致类似以下错误
示例
网格 = {'x': 2, 'y': 2, 'z': 2}
和主机本地输入形状 == (4,)
以及 pspec = P(('x', 'y', 'z'))
由于 pjit
不会使用 jax.Array
将主机本地形状提升为全局形状,因此您会收到以下错误
注意:只有当主机本地形状小于网格形状时,才会看到此错误。
ValueError: One of pjit arguments was given the sharding of
NamedSharding(mesh={'x': 2, 'y': 2, 'chips': 2}, partition_spec=PartitionSpec(('x', 'y', 'chips'),)),
which implies that the global size of its dimension 0 should be divisible by 8,
but it is equal to 4
此错误是有道理的,因为当维度 0
上的值为 4
时,您无法将维度 0 分片 8 次。
如果您仍然将主机本地输入传递给 pjit
,如何进行迁移?我们提供了一些过渡性 API 来帮助您进行迁移
注意:如果您在一个进程上运行 pjit 计算,则不需要这些实用程序。
from jax.experimental import multihost_utils
global_inps = multihost_utils.host_local_array_to_global_array(
local_inputs, mesh, in_pspecs)
global_outputs = pjit(f, in_shardings=in_pspecs,
out_shardings=out_pspecs)(global_inps)
local_outs = multihost_utils.global_array_to_host_local_array(
global_outputs, mesh, out_pspecs)
host_local_array_to_global_array
是一种类型转换,它查看仅具有本地分片的值,并将它的本地形状更改为 pjit
在更改之前假设的值的形状。
传递完全复制的输入,即每个进程上的形状相同,并使用 P(None)
作为 in_axis_resources
仍然受支持。在这种情况下,您不必使用 host_local_array_to_global_array
,因为形状已经是全局的。
key = jax.random.PRNGKey(1)
# As you can see, using host_local_array_to_global_array is not required since in_axis_resources says
# that the input is fully replicated via P(None)
pjit(f, in_shardings=None, out_shardings=None)(key)
# Mixing inputs
global_inp = multihost_utils.host_local_array_to_global_array(
local_inp, mesh, P('data'))
global_out = pjit(f, in_shardings=(P(None), P('data')),
out_shardings=...)(key, global_inp)
FROM_GDA 和 jax.Array#
如果您之前在 pjit
函数的 in_axis_resources
参数中使用了 FROM_GDA
,那么使用 jax.Array
时,无需向 in_axis_resources
传递任何内容,因为 jax.Array
将遵循**计算跟随分片**语义。
例如
pjit(f, in_shardings=FROM_GDA, out_shardings=...) can be replaced by pjit(f, out_shardings=...)
如果您在输入(例如 NumPy 数组等)中混合使用了 PartitionSpecs 和 FROM_GDA
,则可以使用 host_local_array_to_global_array
将它们转换为 jax.Array
。
例如
如果您之前是这样做的
pjitted_f = pjit(
f, in_shardings=(FROM_GDA, P('x'), FROM_GDA, P(None)),
out_shardings=...)
pjitted_f(gda1, np_array1, gda2, np_array2)
那么您可以将其替换为
pjitted_f = pjit(f, out_shardings=...)
array2, array3 = multihost_utils.host_local_array_to_global_array(
(np_array1, np_array2), mesh, (P('x'), P(None)))
pjitted_f(array1, array2, array3, array4)
live_buffers 被 live_arrays 替换#
jax Device
上的 live_buffers
属性已弃用。请改用 jax.live_arrays()
,它与 jax.Array
兼容。
处理传递给 pjit 的主机本地输入(如批次等)#
如果您在**多进程环境**中将主机本地输入传递给 pjit
,请使用 multihost_utils.host_local_array_to_global_array
将批次转换为全局 jax.Array
,然后将其传递给 pjit
。
此类主机本地输入最常见的示例是**输入数据的批次**。
这适用于任何主机本地输入(不仅仅是输入数据的批次)。
from jax.experimental import multihost_utils
batch = multihost_utils.host_local_array_to_global_array(
batch, mesh, batch_partition_spec)
有关此更改和更多示例的详细信息,请参阅上面的 pjit 部分。
RecursionError: 递归调用 jit#
当代码的某些部分禁用了 jax.Array
,然后只为其他部分启用了它时,就会发生这种情况。例如,如果您使用了一些禁用了 jax.Array
的第三方代码,并从该库中获取了 DeviceArray
,然后在您的库中启用了 jax.Array
并将该 DeviceArray
传递给 JAX 函数,这将导致 RecursionError。
当 jax.Array
默认启用时,此错误应该会消失,这样所有库都会返回 jax.Array
,除非它们明确禁用它。