理解 Jaxpr#
更新:2020 年 5 月 3 日(针对提交 f1a46fe)。
从概念上讲,可以将 JAX 变换视为首先对要变换的 Python 函数进行跟踪特化,将其转换为一种小型且行为良好的中间形式,然后使用特定于变换的解释规则对其进行解释。JAX 能够在一个如此小的软件包中包含如此强大的功能的原因之一是,它从一个熟悉且灵活的编程接口(带有 NumPy 的 Python)开始,并且它使用实际的 Python 解释器来完成大部分繁重的工作,将计算的本质提炼成一个简单的静态类型表达式语言,该语言具有有限的高阶特性。这种语言就是 jaxpr 语言。
并非所有 Python 程序都可以通过这种方式进行处理,但事实证明,许多科学计算和机器学习程序都可以。
在继续之前,必须指出,并非所有 JAX 变换都像上面描述的那样真正地实现了一个 jaxpr;有些变换,例如微分或批处理,会在跟踪过程中增量地应用变换。然而,如果想要了解 JAX 的内部工作原理,或者利用 JAX 跟踪的结果,那么理解 jaxpr 非常有用。
Jaxpr 实例表示一个函数,该函数具有一个或多个类型化的参数(输入变量)和一个或多个类型化的结果。结果仅取决于输入变量;没有从封闭作用域捕获的自由变量。输入和输出具有类型,在 JAX 中,这些类型表示为抽象值。代码中用于 jaxpr 的两种相关表示是 jax.core.Jaxpr
和 jax.core.ClosedJaxpr
。 jax.core.ClosedJaxpr
表示一个部分应用的 jax.core.Jaxpr
,并且是使用 jax.make_jaxpr()
检查 jaxpr 时获得的结果。它具有以下字段
jaxpr
是一个表示函数实际计算内容的jax.core.Jaxpr
(如下所述)。
consts
是一个常量列表。
ClosedJaxpr 最有趣的部分是其实际执行内容,它以 jax.core.Jaxpr
的形式表示,并使用以下语法打印
Jaxpr ::= { lambda Var* ; Var+. let
Eqn*
in [Expr+] }
- 其中
jaxpr 的参数显示为两个变量列表,用
;
分隔。第一组变量是为已提升出的常量引入的变量。这些被称为constvars
,在jax.core.ClosedJaxpr
中,consts
字段保存对应的值。第二个变量列表称为invars
,对应于被跟踪的 Python 函数的输入。Eqn*
是一个方程列表,定义了引用中间表达式的中间变量。每个方程将一个或多个变量定义为在某些原子表达式上应用原语的结果。每个方程仅使用输入变量和先前方程定义的中间变量。Expr+
:是 jaxpr 的输出原子表达式(字面量或变量)列表。
方程如下打印
Eqn ::= Var+ = Primitive [ Param* ] Expr+
- 其中
Var+
是一个或多个中间变量,被定义为原语调用(某些原语可以返回多个值)的输出。Expr+
是一个或多个原子表达式,每个表达式要么是变量,要么是字面常量。一个特殊的变量unitvar
或字面量unit
,打印为*
,表示在后续计算中不需要的值,并且已被省略。也就是说,单位只是占位符。Param*
是原语的零个或多个命名参数,用方括号打印。每个参数显示为Name = Value
。
大多数 jaxpr 原语是一阶的(它们只接受一个或多个 Expr
作为参数)
Primitive := add | sub | sin | mul | ...
jaxpr 原语在 jax.lax
模块中进行了文档说明。
例如,以下是为以下函数 func1
生成的 jaxpr
>>> from jax import make_jaxpr
>>> import jax.numpy as jnp
>>> def func1(first, second):
... temp = first + jnp.sin(second) * 3.
... return jnp.sum(temp)
...
>>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
c:f32[8] = sin b
d:f32[8] = mul c 3.0
e:f32[8] = add a d
f:f32[] = reduce_sum[axes=(0,)] e
in (f,) }
这里没有 constvars,a
和 b
是输入变量,它们分别对应于 first
和 second
函数参数。标量字面量 3.0
保持内联。除了操作数 e
之外,reduce_sum
原语还具有命名参数 axes
。
请注意,即使调用 JAX 的程序的执行构建了 jaxpr,Python 级别的控制流和 Python 级别的函数也会正常执行。这意味着,仅仅因为 Python 程序包含函数和控制流,生成的 jaxpr 不必包含控制流或高阶特性。
例如,在跟踪函数 func3
时,JAX 将内联对 inner
的调用以及条件 if second.shape[0] > 4
,并将生成与之前相同的 jaxpr
>>> def func2(inner, first, second):
... temp = first + inner(second) * 3.
... return jnp.sum(temp)
...
>>> def inner(second):
... if second.shape[0] > 4:
... return jnp.sin(second)
... else:
... assert False
...
>>> def func3(first, second):
... return func2(inner, first, second)
...
>>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
c:f32[8] = sin b
d:f32[8] = mul c 3.0
e:f32[8] = add a d
f:f32[] = reduce_sum[axes=(0,)] e
in (f,) }
处理 PyTrees#
在 jaxpr 中没有元组类型;相反,原语接受多个输入并产生多个输出。在处理具有结构化输入或输出的函数时,JAX 将对其进行扁平化,并且在 jaxpr 中它们将显示为输入和输出列表。有关更多详细信息,请参阅 PyTrees 的文档(Pytrees)。
例如,以下代码生成与之前看到的相同的 jaxpr(有两个输入变量,每个输入元组元素一个)
>>> def func4(arg): # Arg is a pair
... temp = arg[0] + jnp.sin(arg[1]) * 3.
... return jnp.sum(temp)
...
>>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
{ lambda ; a:f32[8] b:f32[8]. let
c:f32[8] = sin b
d:f32[8] = mul c 3.0
e:f32[8] = add a d
f:f32[] = reduce_sum[axes=(0,)] e
in (f,) }
常量变量#
jaxprs 中的一些值是常量,因为它们的值不依赖于 jaxpr 的参数。当这些值是标量时,它们直接在 jaxpr 方程中表示;非标量数组常量则提升到顶层 jaxpr,在那里它们对应于常量变量(“constvars”)。这些 constvars 与其他 jaxpr 参数(“invars”)的区别仅在于记账约定。
高阶原语#
jaxpr 包含几个高阶原语。它们比较复杂,因为它们包含子 jaxprs。
条件语句#
JAX 跟踪正常的 Python 条件语句。要捕获用于动态执行的条件表达式,必须使用 jax.lax.switch()
和 jax.lax.cond()
构造函数,它们具有以下签名
lax.switch(index: int, branches: Sequence[A -> B], operand: A) -> B
lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B
这两者都将在内部绑定一个名为 cond
的原语。cond
原语在 jaxprs 中反映了 lax.switch()
的更通用签名:它接受一个整数,表示要执行的分支的索引(钳位到有效的索引范围内)。
例如
>>> from jax import lax
>>>
>>> def one_of_three(index, arg):
... return lax.switch(index, [lambda x: x + 1.,
... lambda x: x - 2.,
... lambda x: x + 3.],
... arg)
...
>>> print(make_jaxpr(one_of_three)(1, 5.))
{ lambda ; a:i32[] b:f32[]. let
c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
d:i32[] = clamp 0 c 2
e:f32[] = cond[
branches=(
{ lambda ; f:f32[]. let g:f32[] = add f 1.0 in (g,) }
{ lambda ; h:f32[]. let i:f32[] = sub h 2.0 in (i,) }
{ lambda ; j:f32[]. let k:f32[] = add j 3.0 in (k,) }
)
] d b
in (e,) }
branches 参数对应于分支函数。在此示例中,这些函数分别接受一个输入变量,对应于 x
。
上述 cond
原语实例接受两个操作数。第一个(d
)是分支索引,然后 b
是要传递给 branches
中由分支索引选择哪个 jaxpr 的操作数(arg
)。
另一个示例,使用 lax.cond()
>>> from jax import lax
>>>
>>> def func7(arg):
... return lax.cond(arg >= 0.,
... lambda xtrue: xtrue + 3.,
... lambda xfalse: xfalse - 3.,
... arg)
...
>>> print(make_jaxpr(func7)(5.))
{ lambda ; a:f32[]. let
b:bool[] = ge a 0.0
c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
d:f32[] = cond[
branches=(
{ lambda ; e:f32[]. let f:f32[] = sub e 3.0 in (f,) }
{ lambda ; g:f32[]. let h:f32[] = add g 3.0 in (h,) }
)
] c a
in (d,) }
在这种情况下,布尔谓词将转换为整数索引(0 或 1),branches
是对应于 false 和 true 分支函数的 jaxprs,按此顺序。同样,每个函数都接受一个输入变量,分别对应于 xfalse
和 xtrue
。
以下示例显示了一种更复杂的情况,其中分支函数的输入是元组,并且 false 分支函数包含一个常量 jnp.ones(1)
,该常量被提升为 constvar
>>> def func8(arg1, arg2): # arg2 is a pair
... return lax.cond(arg1 >= 0.,
... lambda xtrue: xtrue[0],
... lambda xfalse: jnp.array([1]) + xfalse[1],
... arg2)
...
>>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda a:i32[1]; b:f32[] c:f32[1] d:f32[]. let
e:bool[] = ge b 0.0
f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
g:f32[1] = cond[
branches=(
{ lambda ; h:i32[1] i:f32[1] j:f32[]. let
k:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] h
l:f32[1] = add k j
in (l,) }
{ lambda ; m_:i32[1] n:f32[1] o:f32[]. let in (n,) }
)
] f a c d
in (g,) }
While循环#
就像条件语句一样,Python 循环在跟踪期间会被内联。如果要捕获用于动态执行的循环,则必须使用以下几种特殊操作之一:jax.lax.while_loop()
(一个原语)和 jax.lax.fori_loop()
(一个生成 while_loop 原语的帮助函数)
lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C
lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C
在上述签名中,“C”代表循环“携带”值的类型。例如,以下是一个 fori 循环示例
>>> import numpy as np
>>>
>>> def func10(arg, n):
... ones = jnp.ones(arg.shape) # A constant
... return lax.fori_loop(0, n,
... lambda i, carry: carry + ones * 3. + arg,
... arg + ones)
...
>>> print(make_jaxpr(func10)(np.ones(16), 5))
{ lambda ; a:f32[16] b:i32[]. let
c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
d:f32[16] = add a c
_:i32[] _:i32[] e:f32[16] = while[
body_jaxpr={ lambda ; f:f32[16] g:f32[16] h:i32[] i:i32[] j:f32[16]. let
k:i32[] = add h 1
l:f32[16] = mul f 3.0
m:f32[16] = add j l
n:f32[16] = add m g
in (k, i, n) }
body_nconsts=2
cond_jaxpr={ lambda ; o:i32[] p:i32[] q:f32[16]. let
r:bool[] = lt o p
in (r,) }
cond_nconsts=0
] c a 0 b d
in (e,) }
while 原语接受 5 个参数:c a 0 b d
,如下所示
0 个
cond_jaxpr
的常量(因为cond_nconsts
为 0)2 个
body_jaxpr
的常量(c
和a
)3 个携带初始值的参数
Scan循环#
JAX 支持一种特殊的数组元素循环形式(具有静态已知的形状)。迭代次数固定这一事实使得这种循环形式易于反向微分。此类循环使用 jax.lax.scan()
函数构建
lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])
这用 Haskell 类型签名 写成:C
是扫描携带的类型,A
是输入数组的元素类型,B
是输出数组的元素类型。
对于示例,请考虑以下函数 func11
>>> def func11(arr, extra):
... ones = jnp.ones(arr.shape) # A constant
... def body(carry, aelems):
... # carry: running dot-product of the two arrays
... # aelems: a pair with corresponding elements from the two arrays
... ae1, ae2 = aelems
... return (carry + ae1 * ae2 + extra, carry)
... return lax.scan(body, 0., (arr, ones))
...
>>> print(make_jaxpr(func11)(np.ones(16), 5.))
{ lambda ; a:f32[16] b:f32[]. let
c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
d:f32[] e:f32[16] = scan[
_split_transpose=False
jaxpr={ lambda ; f:f32[] g:f32[] h:f32[] i:f32[]. let
j:f32[] = mul h i
k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
l:f32[] = add k j
m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
n:f32[] = add l m
in (n, g) }
length=16
linear=(False, False, False, False)
num_carry=1
num_consts=1
reverse=False
unroll=1
] b 0.0 a c
in (d, e) }
linear
参数描述了每个输入变量在主体中是否保证线性使用。一旦扫描完成线性化,更多参数将是线性的。
scan 原语接受 4 个参数:b 0.0 a c
,其中
一个是主体的自由变量
一个是携带的初始值
接下来的 2 个是扫描操作的数组。
XLA_call#
call 原语源于 JIT 编译,它封装了一个子 jaxpr 以及指定后端和计算应运行的设备的参数。例如
>>> from jax import jit
>>>
>>> def func12(arg):
... @jit
... def inner(x):
... return x + arg * jnp.ones(1) # Include a constant in the inner function
... return arg + inner(arg - 2.)
...
>>> print(make_jaxpr(func12)(1.))
{ lambda ; a:f32[]. let
b:f32[] = sub a 2.0
c:f32[1] = pjit[
name=inner
jaxpr={ lambda ; d:f32[] e:f32[]. let
f:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
h:f32[1] = mul g f
i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e
j:f32[1] = add i h
in (j,) }
] a b
k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
l:f32[1] = add k c
in (l,) }
XLA_pmap#
如果使用 jax.pmap()
变换,则将使用 xla_pmap
原语捕获要映射的函数。请考虑以下示例
>>> from jax import pmap
>>>
>>> def func13(arr, extra):
... def inner(x):
... # use a free variable "extra" and a constant jnp.ones(1)
... return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')
... return pmap(inner, axis_name='rows')(arr)
...
>>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
{ lambda ; a:f32[1,3] b:f32[]. let
c:f32[1,3] = xla_pmap[
axis_name=rows
axis_size=1
backend=None
call_jaxpr={ lambda ; d:f32[] e:f32[3]. let
f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
g:f32[3] = add e f
h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
i:f32[3] = add g h
j:f32[3] = psum[axes=('rows',) axis_index_groups=None] e
k:f32[3] = div i j
in (k,) }
devices=None
donated_invars=(False, False)
global_axis_size=1
in_axes=(None, 0)
is_explicit_global_axis_size=False
name=inner
out_axes=(0,)
] b a
in (c,) }
xla_pmap
原语指定轴的名称(参数 axis_name
)以及要映射的函数的主体作为 call_jaxpr
参数。此参数的值是一个具有 2 个输入变量的 Jaxpr。
参数 in_axes
指定了哪些输入变量应该被映射,哪些应该被广播。在我们的示例中,extra
的值被广播,而 arr
的值被映射。