在 JAX 中编写自定义 Jaxpr 解释器#

Open in Colab Open in Kaggle

JAX 提供了几种可组合的函数转换(jitgradvmap 等),可以编写简洁且加速的代码。

在这里,我们将展示如何通过编写自定义 Jaxpr 解释器向系统添加您自己的函数转换。我们将免费获得与其他所有转换的可组合性。

此示例使用内部 JAX API,这些 API 可能随时中断。任何不在 API 文档 中的内容都应视为内部内容。

import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax import random

JAX 在做什么?#

JAX 提供了一个类似 NumPy 的数值计算 API,可以按原样使用,但 JAX 的真正强大之处在于可组合的函数转换。以 jit 函数转换为例,它接受一个函数并返回一个语义相同的函数,但会被 XLA 惰性编译以用于加速器。

x = random.normal(random.key(0), (5000, 5000))
def f(w, b, x):
  return jnp.tanh(jnp.dot(x, w) + b)
fast_f = jit(f)

当我们调用 fast_f 时会发生什么?JAX 会跟踪该函数并构建一个 XLA 计算图。然后,该图会被 JIT 编译并执行。其他转换的工作方式类似,它们首先跟踪函数,然后以某种方式处理输出跟踪。要了解有关 JAX 跟踪机制的更多信息,您可以参考 README 中的 “工作原理” 部分。

Jaxpr 跟踪器#

在 JAX 中,一个特别重要的跟踪器是 Jaxpr 跟踪器,它将操作记录到 Jaxpr(Jax 表达式)中。Jaxpr 是一种数据结构,可以像微型函数式编程语言一样进行评估,因此 Jaxpr 是函数转换的有用中间表示。

要初步了解 Jaxpr,请考虑 make_jaxpr 转换。make_jaxpr 本质上是一种“漂亮打印”转换:它将一个函数转换为一个函数,该函数在给定示例参数的情况下,生成其计算的 Jaxpr 表示形式。make_jaxpr 对于调试和内省很有用。让我们使用它来看看一些示例 Jaxpr 的结构。

def examine_jaxpr(closed_jaxpr):
  jaxpr = closed_jaxpr.jaxpr
  print("invars:", jaxpr.invars)
  print("outvars:", jaxpr.outvars)
  print("constvars:", jaxpr.constvars)
  for eqn in jaxpr.eqns:
    print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
  print()
  print("jaxpr:", jaxpr)

def foo(x):
  return x + 1
print("foo")
print("=====")
examine_jaxpr(jax.make_jaxpr(foo)(5))

print()

def bar(w, b, x):
  return jnp.dot(w, x) + b + jnp.ones(5), x
print("bar")
print("=====")
examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))
foo
=====
invars: [Var(id=139806762118336):int32[]]
outvars: [Var(id=139806762120064):int32[]]
constvars: []
equation: [Var(id=139806762118336):int32[], 1] add [Var(id=139806762120064):int32[]] {}

jaxpr: { lambda ; a:i32[]. let b:i32[] = add a 1 in (b,) }

bar
=====
invars: [Var(id=139806762348928):float32[5,10], Var(id=139806762349248):float32[5], Var(id=139806762349312):float32[10]]
outvars: [Var(id=139806762349568):float32[5], Var(id=139806762349312):float32[10]]
constvars: []
equation: [Var(id=139806762348928):float32[5,10], Var(id=139806762349312):float32[10]] dot_general [Var(id=139806762349376):float32[5]] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': dtype('float32'), 'out_type': None}
equation: [Var(id=139806762349376):float32[5], Var(id=139806762349248):float32[5]] add [Var(id=139806762349440):float32[5]] {}
equation: [1.0] broadcast_in_dim [Var(id=139806762349504):float32[5]] {'shape': (5,), 'broadcast_dimensions': (), 'sharding': None}
equation: [Var(id=139806762349440):float32[5], Var(id=139806762349504):float32[5]] add [Var(id=139806762349568):float32[5]] {}

jaxpr: { lambda ; a:f32[5,10] b:f32[5] c:f32[10]. let
    d:f32[5] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] a c
    e:f32[5] = add d b
    f:f32[5] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(5,)
      sharding=None
    ] 1.0
    g:f32[5] = add e f
  in (g, c) }
  • jaxpr.invars - Jaxpr 的 invars 是 Jaxpr 的输入变量列表,类似于 Python 函数中的参数。

  • jaxpr.outvars - Jaxpr 的 outvars 是 Jaxpr 返回的变量。每个 Jaxpr 都有多个输出。

  • jaxpr.constvars - constvars 是一个变量列表,这些变量也是 Jaxpr 的输入,但对应于跟踪中的常量(稍后我们将详细介绍)。

  • jaxpr.eqns - 一个方程列表,本质上是 let 绑定。每个方程都是一个输入变量列表,一个输出变量列表,以及一个原始操作,用于评估输入以产生输出。每个方程还有一个 params,一个参数字典。

总而言之,Jaxpr 封装了一个简单的程序,可以使用输入进行评估以产生输出。稍后我们将详细介绍如何执行此操作。现在需要注意的重要一点是,Jaxpr 是一种数据结构,可以以我们想要的任何方式进行操作和评估。

为什么 Jaxpr 有用?#

Jaxpr 是易于转换的简单程序表示形式。而且,由于 JAX 允许我们从 Python 函数中分阶段输出 Jaxpr,它为我们提供了一种转换用 Python 编写的数值程序的方法。

您的第一个解释器:invert#

让我们尝试实现一个简单的函数“反转器”,它接收原始函数的输出并返回生成这些输出的输入。现在,让我们专注于由其他可逆一元函数组成的简单的一元函数。

目标

def f(x):
  return jnp.exp(jnp.tanh(x))
f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)

我们将通过以下方式实现这一点:(1)将 f 跟踪到 Jaxpr 中,然后(2)向后解释 Jaxpr。在向后解释 Jaxpr 时,对于每个方程,我们将在表中查找原始操作的逆操作并应用它。

1. 跟踪函数#

让我们使用 make_jaxpr 将函数跟踪到 Jaxpr 中。

# Importing Jax functions useful for tracing/interpreting.
from functools import wraps

from jax import core
from jax import lax
from jax._src.util import safe_map

jax.make_jaxpr 返回一个封闭的 Jaxpr,它是一个与跟踪中的常量 (literals) 捆绑在一起的 Jaxpr。

def f(x):
  return jnp.exp(jnp.tanh(x))

closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
print(closed_jaxpr.jaxpr)
print(closed_jaxpr.literals)
{ lambda ; a:f32[5]. let b:f32[5] = tanh a; c:f32[5] = exp b in (c,) }
[]

2. 评估 Jaxpr#

在我们编写自定义 Jaxpr 解释器之前,让我们首先实现“默认”解释器 eval_jaxpr,它按原样评估 Jaxpr,计算原始、未转换的 Python 函数将计算的相同值。

为此,我们首先创建一个环境来存储每个变量的值,并使用我们在 Jaxpr 中评估的每个方程更新该环境。

def eval_jaxpr(jaxpr, consts, *args):
  # Mapping from variable -> value
  env = {}

  def read(var):
    # Literals are values baked into the Jaxpr
    if type(var) is core.Literal:
      return var.val
    return env[var]

  def write(var, val):
    env[var] = val

  # Bind args and consts to environment
  safe_map(write, jaxpr.invars, args)
  safe_map(write, jaxpr.constvars, consts)

  # Loop through equations and evaluate primitives using `bind`
  for eqn in jaxpr.eqns:
    # Read inputs to equation from environment
    invals = safe_map(read, eqn.invars)
    # `bind` is how a primitive is called
    outvals = eqn.primitive.bind(*invals, **eqn.params)
    # Primitives may return multiple outputs or not
    if not eqn.primitive.multiple_results:
      outvals = [outvals]
    # Write the results of the primitive into the environment
    safe_map(write, eqn.outvars, outvals)
  # Read the final result of the Jaxpr from the environment
  return safe_map(read, jaxpr.outvars)
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))
/tmp/ipykernel_1211/3734673940.py:7: DeprecationWarning: jax.core.Literal is deprecated. Use jax.extend.core.Literal instead, and see https://jax.ac.cn/en/latest/jax.extend.html for details.
  if type(var) is core.Literal:
[Array([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)]

请注意,即使原始函数没有,eval_jaxpr 也将始终返回一个扁平列表。

此外,此解释器不处理高阶原始操作(如 jitpmap),我们将在本指南中不对此进行介绍。您可以参考 core.eval_jaxpr (链接) 以查看此解释器未涵盖的边缘情况。

自定义 inverse Jaxpr 解释器#

inverse 解释器与 eval_jaxpr 看起来没有太大区别。我们首先设置注册表,该注册表会将原始操作映射到它们的逆操作。然后,我们将编写一个自定义解释器,该解释器将在注册表中查找原始操作。

事实证明,此解释器也与反向模式自动微分中使用的“转置”解释器 类似

inverse_registry = {}

现在,我们将为一些原始操作注册逆操作。按照惯例,JAX 中的原始操作以 _p 结尾,许多流行的操作位于 lax 中。

inverse_registry[lax.exp_p] = jnp.log
inverse_registry[lax.tanh_p] = jnp.arctanh

inverse 将首先跟踪函数,然后自定义解释 Jaxpr。让我们设置一个简单的框架。

def inverse(fun):
  @wraps(fun)
  def wrapped(*args, **kwargs):
    # Since we assume unary functions, we won't worry about flattening and
    # unflattening arguments.
    closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
    out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
    return out[0]
  return wrapped

现在,我们只需要定义 inverse_jaxpr,它将向后遍历 Jaxpr,并在可能的情况下反转原始操作。

def inverse_jaxpr(jaxpr, consts, *args):
  env = {}

  def read(var):
    if type(var) is core.Literal:
      return var.val
    return env[var]

  def write(var, val):
    env[var] = val
  # Args now correspond to Jaxpr outvars
  safe_map(write, jaxpr.outvars, args)
  safe_map(write, jaxpr.constvars, consts)

  # Looping backward
  for eqn in jaxpr.eqns[::-1]:
    #  outvars are now invars
    invals = safe_map(read, eqn.outvars)
    if eqn.primitive not in inverse_registry:
      raise NotImplementedError(
          f"{eqn.primitive} does not have registered inverse.")
    # Assuming a unary function
    outval = inverse_registry[eqn.primitive](*invals)
    safe_map(write, eqn.invars, [outval])
  return safe_map(read, jaxpr.invars)

就是这样!

def f(x):
  return jnp.exp(jnp.tanh(x))

f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)

重要的是,您可以跟踪 Jaxpr 解释器。

jax.make_jaxpr(inverse(f))(f(1.))
{ lambda ; a:f32[]. let b:f32[] = log a; c:f32[] = atanh b in (c,) }

这就是向系统添加新转换所需的一切,并且您可以免费获得与所有其他转换的组合!例如,我们可以将 jitvmapgradinverse 一起使用!

jit(vmap(grad(inverse(f))))((jnp.arange(5) + 1.) / 5.)
Array([-3.1440797, 15.584931 ,  2.2551253,  1.3155028,  1.       ],      dtype=float32, weak_type=True)

给读者的练习#

  • 处理具有多个参数的原始操作,其中输入是部分已知的,例如 lax.add_plax.mul_p

  • 处理 xla_callxla_pmap 原始操作,这些操作将无法使用编写的 eval_jaxprinverse_jaxpr