全域分段#

mattjj@ 2020 年 9 月 25 日

这更像是一个升级指南,而不是一个设计文档。

目录#

概括#

发生了什么?#

JAX 的跟踪基础设施的一项名为“全域分段”(omnistaging)的更改 (jax-ml/jax#3370) 已在 jax==0.2.0 中启用。此更改提高了内存性能、跟踪执行时间并简化了 jax 内部结构,但可能会导致一些现有代码中断。中断通常是错误代码的结果,因此从长远来看,最好修复这些错误,但也可以禁用全域分段作为临时解决方法。我们很乐意帮助您修复!

如何知道全域分段是否破坏了我的代码?#

判断全域分段是否是罪魁祸首的最简单方法是禁用全域分段,看看问题是否消失。请参阅下面的启用全域分段时可能出现哪些问题?部分。

如何暂时禁用全域分段?#

注意:这适用于 JAX 版本 0.2.0 到 0.2.11;在 JAX 版本 0.2.12 及更高版本中无法禁用全域分段

可以通过以下方式暂时禁用全域分段

  1. 将 shell 环境变量 JAX_OMNISTAGING 设置为假值;

  2. 如果您的代码使用 absl 解析标志,则将布尔标志 jax_omnistaging 设置为假值;

  3. 在主文件的顶部附近使用此语句

jax.config.disable_omnistaging()

如何修复全域分段暴露的错误?#

到目前为止,全域分段最常见的问题是使用 jax.numpy 来计算形状值或其他跟踪时常量。请参见下面的代码块以获取快速示例,并查看启用全域分段时可能出现哪些问题?部分以了解完整详细信息和其他问题。

不要这样做

@jit
def f(x):
  input_size = jnp.prod(x.shape)
  if input_size > 100:
    ...

而要这样做

import numpy as np

@jit
def f(x):
  input_size = np.prod(x.shape)
  if input_size > 100:
    ...

不要将 jax.numpy 视为 numpy 的直接替代品,现在最好认为仅当您想在加速器(如 GPU)上执行计算时才使用 jax.numpy 操作。

什么是“全域分段”,为什么它很有用?#

全域分段是 JAX 核心升级的名称,旨在将更多计算从逐个操作的 Python 分段到 XLA,并避免 jitpmap 和控制流原语中的任何“跟踪时常量折叠”。因此,全域分段通过减少跟踪期间的碎片化和为 XLA 生成更少的编译时常量来提高 JAX 的内存性能(有时会显著提高)。它还可以通过消除跟踪时的逐个操作执行来提高跟踪性能。此外,全域分段简化了 JAX 核心内部结构,修复了许多未解决的错误,并为即将到来的重要功能奠定了基础。

“全域分段”的名称意味着分段出所有可能的事物。

玩具示例#

jitpmap 这样的 JAX 转换会将计算分段到 XLA。也就是说,我们将它们应用于包含多个原始操作的函数,以便这些操作不是从 Python 中逐个执行,而是都成为端到端优化的 XLA 计算的一部分。

但是具体分段哪些操作呢?在全域分段之前,JAX 仅基于数据依赖关系来分段计算。以下是一个示例函数,后面是它在全域分段更改之前分段出的 XLA HLO 程序

from jax import jit
import jax.numpy as jnp

@jit
def f(x):
  y = jnp.add(1, 1)
  return x * y

f(3)
ENTRY jit_f.6 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(2)
  multiply.4 = s32[] multiply(parameter.1, constant.3)
  ROOT tuple.5 = (s32[]) tuple(multiply.4)
}

请注意,add 操作没有被分段。相反,我们只看到一个乘法。

以下是此函数在全域分段更改之后生成的 HLO

ENTRY jit_f.8 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(1)
  constant.4 = s32[] constant(1)
  add.5 = s32[] add(constant.3, constant.4)
  multiply.6 = s32[] multiply(parameter.1, add.5)
  ROOT tuple.7 = (s32[]) tuple(multiply.6)
}

稍微不那么玩具的示例#

以下是一个在实践中可能出现的稍微不那么玩具的示例,当我们想要创建布尔掩码时

import jax.numpy as jnp
from jax import lax

@jit
def select_tril(x):
  mask = jnp.arange(x.shape[0])[:, None] > jnp.arange(x.shape[1])
  return lax.select(mask, x, jnp.zeros_like(x))  # lax.select is like jnp.where

x = np.arange(12).reshape((3, 4))
select_tril(x)

全域分段之前

ENTRY jit_select_tril.8 {
  constant.3 = pred[] constant(false)
  constant.1 = pred[3,4]{1,0} constant({...})
  parameter.2 = s32[3,4]{1,0} parameter(0)
  constant.4 = s32[] constant(0)
  broadcast.5 = s32[3,4]{1,0} broadcast(constant.4), dimensions={}
  select.6 = s32[3,4]{1,0} select(constant.1, parameter.2, broadcast.5)
  ROOT tuple.7 = (s32[3,4]{1,0}) tuple(select.6)
}

select 操作被分段了,但是构造常量 mask 的操作没有被分段。构造 mask 的操作不是被分段,而是在 Python 跟踪时逐个操作执行,XLA 只看到一个表示 mask 值的编译时常量 constant.1。这很不幸,因为如果我们分段了构造 mask 的操作,XLA 本可以将其融合到 select 中并完全避免实现结果。结果,我们最终浪费了可能很大的常量的内存,浪费了分派多个未融合的逐个操作 XLA 计算的时间,甚至可能使内存碎片化。

(与为 jnp.zeros_like(x) 构造零数组相对应的 broadcast 已被分段,因为 JAX 对于来自 jax-ml/jax#1668 的非常简单的表达式是懒惰的。在全域分段之后,我们可以删除该懒惰的子语言并简化 JAX 内部结构。)

未分段 mask 创建的原因是,在全域分段之前,jit 是基于数据依赖关系运行的。也就是说,jit 仅分段函数中那些具有对参数的数据依赖关系的操作。控制流原语和 pmap 的行为类似。对于 select_tril,构造常量 mask 的操作不具有对参数 x 的数据依赖关系,因此它们不会被分段;只有 lax.select 调用具有数据依赖关系。

通过全域分段,jit 转换函数的动态上下文中所有 jax.numpy 调用都将分段到 XLA。也就是说,在全域分段之后,XLA 为 select_tril 看到的计算是

ENTRY jit_select_tril.16 {
  constant.4 = pred[] constant(false)
  iota.1 = s32[3]{0} iota(), iota_dimension=0
  broadcast.5 = s32[3,1]{1,0} broadcast(iota.1), dimensions={0}
  reshape.7 = s32[3]{0} reshape(broadcast.5)
  broadcast.8 = s32[3,4]{1,0} broadcast(reshape.7), dimensions={0}
  iota.2 = s32[4]{0} iota(), iota_dimension=0
  broadcast.6 = s32[1,4]{1,0} broadcast(iota.2), dimensions={1}
  reshape.9 = s32[4]{0} reshape(broadcast.6)
  broadcast.10 = s32[3,4]{1,0} broadcast(reshape.9), dimensions={1}
  compare.11 = pred[3,4]{1,0} compare(broadcast.8, broadcast.10), direction=GT
  parameter.3 = s32[3,4]{1,0} parameter(0)
  constant.12 = s32[] constant(0)
  broadcast.13 = s32[3,4]{1,0} broadcast(constant.12), dimensions={}
  select.14 = s32[3,4]{1,0} select(compare.11, parameter.3, broadcast.13)
  ROOT tuple.15 = (s32[3,4]{1,0}) tuple(select.14)
}

启用全域分段时可能出现哪些问题?#

由于在 jitpmap 的动态上下文中,将所有 jax.numpy 操作从 Python 分段到 XLA,因此一些以前可以工作的代码可能会开始引发响亮的错误。如下所述,这些行为在全域分段之前已经存在错误,但全域分段使其成为硬错误。

使用 jax.numpy 进行形状计算#

示例#

from jax import jit
import jax.numpy as jnp

@jit
def ex1(x):
  size = jnp.prod(jnp.array(x.shape))
  return x.reshape((size,))

ex1(jnp.ones((3, 4)))

错误消息#

[... full traceback ...]
  File "/home/mattjj/packages/jax/jax/core.py", line 862, in raise_concretization_error
    raise ConcretizationTypeError(msg)
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function ex1 at ex1.py:4, this value became a tracer due to JAX operations on these lines:

  operation c:int32[] = reduce_prod[ axes=(0,) ] b:int32[2]
    from line ex1.py:6 (ex1)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.ac.cn/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>

解释#

通过全域分段,我们无法像上面使用 jnp.prod 那样使用 jax.numpy 进行形状计算,因为在 jit 函数的动态上下文中,这些操作将从 Python 中分段出来作为在执行时计算的值,但是我们需要它们成为编译时(因此也是跟踪时)常量。

在全域分段之前,此代码不会引发错误,但这是一种常见的性能错误:jnp.prod 计算将在跟踪时在设备上执行,这意味着额外的编译、传输、同步、分配以及可能的内存碎片。

解决方案#

解决方案很简单,就是使用原始的 numpy 来进行此类形状计算。我们不仅避免了错误,而且还将计算保留在主机上(并降低了开销)。

此问题在代码中非常普遍,以至于我们试图使错误消息特别好。除了显示抽象跟踪器值导致问题的位置的堆栈跟踪(全堆栈跟踪中的 jnp.reshape 行,在 omni.py:10 上),我们还解释了为什么此值首先成为跟踪器,方法是指向导致它成为抽象跟踪器的上游原始操作(来自 omni.py:9 上 jnp.prodreduce_prod)和跟踪器所属的 jit 装饰函数(omni.py:6 上的 ex1)。

副作用#

示例#

from jax import jit
from jax import random

key = random.PRNGKey(0)

def init():
  global key
  key, subkey = random.split(key)
  return random.normal(subkey, ())

print(init())  # -1.2515389
print(init())  # -0.58665067

init = jit(init)
print(init())  # 0.48648298
print(init())  # 0.48648298  !!

最后一次调用重复了随机性,但没有硬错误,因为我们没有重新执行 Python。但是,如果我们查看 key,我们会看到一个转义的跟踪器当启用全暂存时

print(key) # Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>

在全暂存之前,random.split 调用不会被暂存出去,因此我们不会得到一个转义的跟踪器。代码仍然存在缺陷,因为 JIT 编译的函数不会重现原始函数的语义(由于重复使用相同的 PRNG 密钥),最终是由于副作用造成的。

启用全暂存后,如果我们再次访问 key,我们将收到一个转义的跟踪器错误

random.normal(key, ())

错误消息#

[... full stack trace …]
  File "/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 836, in _assert_live
    raise core.escaped_tracer_error(msg)
jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: tracer created on line example.py:8 (init).

解释#

我们发现的第二大类全暂存问题与具有副作用的代码有关。这段代码通过转换有副作用的函数已经违反了 JAX 的保证,但是由于全暂存之前的“跟踪时常量折叠”行为,一些有副作用的函数仍然可以正确运行。全暂存可以捕获更多此类错误。

解决方案#

解决方案是识别依赖副作用的 JAX 转换函数,并将它们重写为无副作用的函数。

基于 XLA 优化的小数值差异#

因为启用全暂存后,更多的计算被暂存到 XLA,而不是某些计算在跟踪时执行,这可能会导致浮点运算的重新排序。因此,我们看到数值行为发生了变化,这会导致在启用全暂存时,容差过紧的测试失败。

依赖已更改的 JAX 内部 API#

全暂存涉及到对 JAX 核心代码的一些重大修改,包括删除或更改内部函数。任何依赖此类 JAX 内部 API 的代码在启用全暂存后都可能崩溃,要么出现构建错误(来自 pytype),要么出现运行时错误。

触发 XLA 编译时错误#

因为全暂存涉及到将更多代码暂存到 XLA,我们看到它在某些后端上触发了预先存在的 XLA 编译时错误。处理这些问题的最好方法是报告它们,以便我们可以与 XLA 团队合作进行修复。