全阶段#

mattjj@ 2020 年 9 月 25 日

这更像是一个升级指南,而不是设计文档。

内容#

tl;dr#

发生了什么事?#

JAX 的追踪基础设施中名为“omnistaging” 的一项变更 (google/jax#3370) 已在 jax==0.2.0 中启用。此变更提高了内存性能、追踪执行时间,并简化了 jax 内部,但可能会导致某些现有代码失效。失效通常是由于代码存在错误,因此从长远来看,最好修复这些错误,但也可以将 omnistaging 作为临时解决方法禁用。我们很乐意帮助您解决问题!

如何知道我的代码是否因 omnistaging 而失效?#

判断 omnistaging 是否是原因的最简单方法是禁用 omnistaging,然后观察问题是否消失。请参阅以下部分 启用 omnistaging 时可能出现哪些问题?

如何暂时禁用 omnistaging?#

注意:这适用于 JAX 版本 0.2.0 到 0.2.11;在 JAX 版本 0.2.12 及更高版本中,无法禁用 omnistaging

可以暂时通过以下方法禁用 omnistaging:

  1. 将 shell 环境变量 JAX_OMNISTAGING 设置为非真值;

  2. 如果您的代码使用 absl 解析标志,则将布尔标志 jax_omnistaging 设置为非真值;

  3. 在主文件开头附近使用以下语句

jax.config.disable_omnistaging()

如何修复 omnistaging 暴露的错误?#

使用 omnistaging 时最常见的问题是使用 jax.numpy 计算形状值或其他追踪时间常量。请参阅以下代码块以获取一个快速示例,并参阅部分 启用 omnistaging 时可能出现哪些问题? 以获取更多详细信息以及其他问题。

不要这样做

@jit
def f(x):
  input_size = jnp.prod(x.shape)
  if input_size > 100:
    ...

而是这样做

import numpy as np

@jit
def f(x):
  input_size = np.prod(x.shape)
  if input_size > 100:
    ...

不要将 jax.numpy 视为 numpy 的直接替代品,现在最好将使用 jax.numpy 操作视为仅在您希望对加速器(例如 GPU)执行计算时进行。

什么是“omnistaging” 以及它为何有用?#

Omnistaging 是 JAX 核心升级的名称,旨在将更多计算从逐操作的 Python 迁移到 XLA,并避免在 jitpmap 和控制流基元中进行任何“追踪时间常量折叠”。因此,omnistaging 通过减少追踪过程中的碎片化以及为 XLA 生成更少的较大编译时间常量来提高 JAX 的内存性能(有时显著提高)。它还可以通过消除追踪时间时的逐操作执行来提高追踪性能。此外,omnistaging 简化了 JAX 核心内部,修复了许多未解决的错误,并为即将推出的重要功能奠定了基础。

“omnistaging” 这个名称表示尽可能地进行迁移。

玩具示例#

JAX 变换(如 jitpmap)将计算迁移到 XLA。也就是说,我们将它们应用于包含多个基元操作的函数,以便所有操作不是从 Python 中一次执行一个,而是全部作为一项端到端优化的 XLA 计算的一部分。

但是,究竟哪些操作会被迁移出去呢?在 omnistaging 之前,JAX 基于数据依赖关系来迁移计算。以下是一个示例函数,以及在 omnistaging 更改之前迁移的 XLA HLO 程序

from jax import jit
import jax.numpy as jnp

@jit
def f(x):
  y = jnp.add(1, 1)
  return x * y

f(3)
ENTRY jit_f.6 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(2)
  multiply.4 = s32[] multiply(parameter.1, constant.3)
  ROOT tuple.5 = (s32[]) tuple(multiply.4)
}

请注意,add 操作并未迁移出去。相反,我们只看到一个乘法。

以下是 omnistaging 更改之后从该函数生成的 HLO

ENTRY jit_f.8 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(1)
  constant.4 = s32[] constant(1)
  add.5 = s32[] add(constant.3, constant.4)
  multiply.6 = s32[] multiply(parameter.1, add.5)
  ROOT tuple.7 = (s32[]) tuple(multiply.6)
}

稍微不太玩具化的示例#

以下是一个不太玩具化的示例,它在实践中出现,当我们想要创建布尔掩码时

import jax.numpy as jnp
from jax import lax

@jit
def select_tril(x):
  mask = jnp.arange(x.shape[0])[:, None] > jnp.arange(x.shape[1])
  return lax.select(mask, x, jnp.zeros_like(x))  # lax.select is like jnp.where

x = np.arange(12).reshape((3, 4))
select_tril(x)

omnistaging 之前

ENTRY jit_select_tril.8 {
  constant.3 = pred[] constant(false)
  constant.1 = pred[3,4]{1,0} constant({...})
  parameter.2 = s32[3,4]{1,0} parameter(0)
  constant.4 = s32[] constant(0)
  broadcast.5 = s32[3,4]{1,0} broadcast(constant.4), dimensions={}
  select.6 = s32[3,4]{1,0} select(constant.1, parameter.2, broadcast.5)
  ROOT tuple.7 = (s32[3,4]{1,0}) tuple(select.6)
}

select 操作被迁移出去,但构建常量 mask 的操作并未被迁移出去。构建 mask 的操作不会被迁移出去,而是会在 Python 追踪时间时逐操作执行,而 XLA 只能看到一个表示 mask 值的编译时间常量 constant.1。这很可惜,因为如果我们将构建 mask 的操作迁移出去,XLA 可以将它们与 select 合并,并避免完全物化结果。因此,我们最终会因为一个可能很大的常量而浪费内存,因为调度了多个未合并的逐操作 XLA 计算而浪费时间,甚至可能会导致内存碎片化。

(与构建 jnp.zeros_like(x) 的零数组对应的 broadcast 会被迁移出去,因为 JAX 对 google/jax#1668 中非常简单的表达式很懒惰。omnistaging 之后,我们可以删除该惰性子语言并简化 JAX 内部。)

创建 mask 未被迁移出去的原因是,在 omnistaging 之前,jit 基于数据依赖关系运行。也就是说,jit 仅迁移出函数中对参数具有数据依赖关系的操作。控制流基元和 pmap 的行为类似。在 select_tril 的情况下,构建常量 mask 的操作与参数 x 没有数据依赖关系,因此不会被迁移出去;只有 lax.select 调用具有数据依赖关系。

使用 omnistaging 时,jit 变换函数的动态上下文中所有的 jax.numpy 调用都会被迁移到 XLA。也就是说,omnistaging 之后,XLA 为 select_tril 看到的计算是

ENTRY jit_select_tril.16 {
  constant.4 = pred[] constant(false)
  iota.1 = s32[3]{0} iota(), iota_dimension=0
  broadcast.5 = s32[3,1]{1,0} broadcast(iota.1), dimensions={0}
  reshape.7 = s32[3]{0} reshape(broadcast.5)
  broadcast.8 = s32[3,4]{1,0} broadcast(reshape.7), dimensions={0}
  iota.2 = s32[4]{0} iota(), iota_dimension=0
  broadcast.6 = s32[1,4]{1,0} broadcast(iota.2), dimensions={1}
  reshape.9 = s32[4]{0} reshape(broadcast.6)
  broadcast.10 = s32[3,4]{1,0} broadcast(reshape.9), dimensions={1}
  compare.11 = pred[3,4]{1,0} compare(broadcast.8, broadcast.10), direction=GT
  parameter.3 = s32[3,4]{1,0} parameter(0)
  constant.12 = s32[] constant(0)
  broadcast.13 = s32[3,4]{1,0} broadcast(constant.12), dimensions={}
  select.14 = s32[3,4]{1,0} select(compare.11, parameter.3, broadcast.13)
  ROOT tuple.15 = (s32[3,4]{1,0}) tuple(select.14)
}

启用 omnistaging 时可能出现哪些问题?#

由于在 jitpmap 的动态上下文中,将所有 jax.numpy 操作从 Python 迁移到 XLA,因此,一些以前有效的代码可能会开始引发严重的错误。正如以下所述,这些行为在 omnistaging 之前就已经存在错误,但 omnistaging 将它们变成了硬错误。

jax.numpy 用于形状计算#

示例#

from jax import jit
import jax.numpy as jnp

@jit
def ex1(x):
  size = jnp.prod(jnp.array(x.shape))
  return x.reshape((size,))

ex1(jnp.ones((3, 4)))

错误消息#

[... full traceback ...]
  File "/home/mattjj/packages/jax/jax/core.py", line 862, in raise_concretization_error
    raise ConcretizationTypeError(msg)
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function ex1 at ex1.py:4, this value became a tracer due to JAX operations on these lines:

  operation c:int32[] = reduce_prod[ axes=(0,) ] b:int32[2]
    from line ex1.py:6 (ex1)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.ac.cn/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>

解释#

使用 omnistaging 时,我们不能将 jax.numpy 用于形状计算,如上面对 jnp.prod 的使用,因为在 jit 函数的动态上下文中,这些操作将被迁移出 Python,作为在执行时计算的值,但我们需要它们是编译时间(因此是追踪时间)常量。

在 omnistaging 之前,此代码不会引发错误,但它是一个常见的性能错误:jnp.prod 计算将在追踪时间在设备上执行,这意味着额外的编译、传输、同步、分配以及潜在的内存碎片化。

解决方案#

解决方案是简单地将原始 numpy 用于此类形状计算。我们不仅避免了错误,而且还将计算保留在主机上(并具有更低的开销)。

此问题在代码中很常见,因此我们尝试使错误消息特别好。除了显示抽象追踪值导致问题的堆栈跟踪(堆栈跟踪中的完整堆栈跟踪中的 jnp.reshape 行,在 omni.py:10 上),我们还通过指向导致它成为抽象追踪值的原始基元操作(omni.py:9 上的 reduce_prod 来自 jnp.prod)以及追踪器所属的 jit 装饰函数 (ex1 在 omni.py:6 上) 来解释此值是如何成为追踪器的。

副作用#

示例#

from jax import jit
from jax import random

key = random.PRNGKey(0)

def init():
  global key
  key, subkey = random.split(key)
  return random.normal(subkey, ())

print(init())  # -1.2515389
print(init())  # -0.58665067

init = jit(init)
print(init())  # 0.48648298
print(init())  # 0.48648298  !!

最后一个调用重复了随机性,但没有硬错误,因为我们没有重新执行 Python。但如果我们看一下 key,我们会看到当 omnistaging 启用时,会产生一个转义追踪器

print(key) # Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>

在 omnistaging 之前,random.split 调用不会被迁移出去,因此我们不会获得转义追踪器。代码仍然存在错误,因为 jitted 函数不会复制原始函数的语义(因为重复使用相同的 PRNG 密钥),最终是由于副作用造成的。

omnistaging 启用后,如果我们再次触碰 key,我们会获得一个转义追踪器错误

random.normal(key, ())

错误消息#

[... full stack trace …]
  File "/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 836, in _assert_live
    raise core.escaped_tracer_error(msg)
jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: tracer created on line example.py:8 (init).

解释#

我们发现的第二大类 omnistaging 问题与副作用代码有关。此代码通过转换有副作用的函数,已经违反了 JAX 的保修,但由于 omnistaging 之前的“追踪时间常量折叠”行为,某些有副作用的函数仍然可以正常运行。omnistaging 会捕获更多此类错误。

解决方案#

解决方案是识别依赖于副作用的 JAX 变换函数,并将它们重写为没有副作用。

基于 XLA 优化的微小数值差异#

由于使用 omnistaging 时,更多计算会被迁移到 XLA,而不是一些在追踪时间执行,因此这可能会导致重新排序浮点运算。因此,我们发现数值行为发生了变化,当 omnistaging 启用时,会导致具有过紧容差的测试失败。

对已更改的 JAX 内部 API 的依赖#

Omnistaging 涉及对 JAX 核心代码进行一些重大修改,包括删除或更改内部函数。任何依赖于此类 JAX 内部 API 的代码在 omnistaging 启用时都可能失效,无论是构建错误(来自 pytype)还是运行时错误。

触发 XLA 编译时间错误#

由于全方位分段涉及将更多代码分段到 XLA,我们发现它在某些后端触发了预先存在的 XLA 编译时错误。对于这些错误,最好的做法是报告它们,以便我们与 XLA 团队合作进行修复。