Autodidax: 从零开始构建 JAX 核心#
是否曾想了解 JAX 的工作原理,但发现其实现令人费解? 那么你很幸运! 通过阅读本教程,你将了解 JAX 核心系统中的每一个重要概念。 你甚至会了解我们那些奇怪的术语!
这是一个正在进行中的草稿。 还有一些重要的组成部分缺失,将在第 5 部分和第 6 部分(以及更多部分?)中出现。 这里还有一些简化,我们尚未将其应用于主系统,但我们会这样做。
第一部分:作为解释器的转换:标准求值、jvp
和 vmap
#
我们想要转换如下所示的函数
def f(x):
  # Example function to transform: it applies the primitives sin and mul, then
  # neg and add (via the infix operators).
  y = sin(x) * 2.
  z = - y + x
  return z
将诸如 sin
之类的函数以及构成中缀运算符(mul
、add
和 neg
)的算术运算视为原语(primitive)运算,即它们是处理的原子单元,而不是由其他运算组合而成。
“转换”意味着“以不同的方式解释”。 我们不想进行标准的解释(即将原语运算应用于数值输入以产生数值输出),而是要覆盖原语的应用,并让不同的值流经我们的程序。 例如,我们可能希望用其 JVP 规则的应用替换每个原语运算的应用,并让原始值-切线值对(primal-tangent pair)流经我们的程序。 此外,我们希望能够组合多个转换,从而形成解释器堆栈。
JAX 核心机制#
我们可以实现解释器堆栈,甚至可以让这些解释器在我们执行要转换的 Python 函数的过程中即时完成各自的工作。 首先,让我们定义这些原语,以便拦截对它们的应用:
from typing import NamedTuple
class Primitive(NamedTuple):
  """A named primitive operation; interpretation rules are keyed on it."""
  name: str


# Arithmetic primitives (backing the infix operators).
add_p, mul_p, neg_p = Primitive('add'), Primitive('mul'), Primitive('neg')
# Elementwise trigonometric primitives.
sin_p, cos_p = Primitive('sin'), Primitive('cos')
# Reductions and comparisons.
reduce_sum_p = Primitive('reduce_sum')
greater_p, less_p = Primitive('greater'), Primitive('less')
# Shape-manipulating primitives.
transpose_p, broadcast_p = Primitive('transpose'), Primitive('broadcast')
def add(x, y):
  """Elementwise sum of x and y."""
  return bind1(add_p, x, y)


def mul(x, y):
  """Elementwise product of x and y."""
  return bind1(mul_p, x, y)


def neg(x):
  """Elementwise negation of x."""
  return bind1(neg_p, x)


def sin(x):
  """Elementwise sine of x."""
  return bind1(sin_p, x)


def cos(x):
  """Elementwise cosine of x."""
  return bind1(cos_p, x)


def greater(x, y):
  """Elementwise comparison x > y."""
  return bind1(greater_p, x, y)


def less(x, y):
  """Elementwise comparison x < y."""
  return bind1(less_p, x, y)


def transpose(x, perm):
  """Permute the axes of x by perm (passed to bind as a keyword param)."""
  return bind1(transpose_p, x, perm=perm)


def broadcast(x, shape, axes):
  """Broadcast x to shape; axes lists the output axes that x does not have."""
  return bind1(broadcast_p, x, shape=shape, axes=axes)
def reduce_sum(x, axis=None):
  """Sum x over the given axes.

  axis may be None (reduce over every axis), a single integer, or a sequence of
  integers. Python and NumPy integers are both accepted, and negative values
  count from the end. Axes are normalized to a tuple of non-negative indices
  because the abstract-evaluation and batching rules compare axis indices
  directly and would compute wrong shapes for negative or non-tuple axes.

  Raises:
    ValueError: if an axis is out of bounds for x.
  """
  ndim = np.ndim(x)
  if axis is None:
    axis = tuple(range(ndim))
  elif isinstance(axis, (int, np.integer)):
    axis = (axis,)
  normalized = []
  for ax in axis:
    ax = int(ax)
    if not -ndim <= ax < ndim:
      raise ValueError(
          f"axis {ax} is out of bounds for array of dimension {ndim}")
    normalized.append(ax % ndim)
  return bind1(reduce_sum_p, x, axis=tuple(normalized))
def bind1(prim, *args, **params):
  """Bind a primitive that yields exactly one output, and return that output."""
  (single_out,) = bind(prim, *args, **params)
  return single_out
我们稍后将设置数组数据类型和中缀运算符方法。
Primitive
只是一个具有名称的对象,我们可以在该对象上附加解释规则(每个转换一个)。 bind
函数是我们的拦截点:它将根据参数在跟踪器中的封装方式以及哪些解释器处于活动状态来确定要应用哪个转换规则。
用户代码调用的函数(如 add
和 sin
)只是对 bind
调用的包装。 这些包装器允许我们控制如何将参数传递给 bind
,尤其是我们遵循一个方便的内部约定:当我们调用 bind
时,我们将表示数组数据的值作为位置参数传递,并通过关键字将诸如 axis
之类的元数据传递给 reduce_sum_p
。此调用约定简化了一些核心逻辑(例如,因为下面要定义的 Tracer
类的实例只能出现在 bind
的位置参数中)。 包装器还可以提供文档字符串!
我们将活动解释器表示为一个堆栈。 该堆栈只是一个简单的 list
,每个元素都是一个容器,其中包含一个整数级别(对应于元素在堆栈中的高度)、一个解释器类型(我们将其称为 trace_type
)以及解释器需要的任何全局数据的可选字段。 我们将每个元素称为 MainTrace
,尽管 “解释器” 可能更具描述性。
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any
class MainTrace(NamedTuple):
level: int
trace_type: type['Trace']
global_data: Any | None
trace_stack: list[MainTrace] = []
dynamic_trace: MainTrace | None = None # to be employed in Part 3
@contextmanager
def new_main(trace_type: type['Trace'], global_data=None):
level = len(trace_stack)
main = MainTrace(level, trace_type, global_data)
trace_stack.append(main)
try:
yield main
finally:
trace_stack.pop()
当我们即将应用转换时,我们将使用 new_main
将另一个解释器推入堆栈。 然后,当我们在函数中应用原语时,我们可以认为 bind
首先由堆栈顶部的跟踪(即具有最高级别的跟踪)解释。 如果第一个解释器本身在其原语的解释规则中绑定了其他原语,例如 sin_p
的 JVP 规则如何绑定 cos_p
和 mul_p
,那么这些 bind
调用将由下一级别的解释器处理。
解释器堆栈的底部是什么? 在底部,我们知道所有转换解释器都已完成,我们只想执行标准求值。 因此,在底部,我们将放置一个求值解释器。
让我们草拟解释器的接口,该接口基于 Trace
和 Tracer
基类。 Tracer
表示一个封装的值,可能携带解释器使用的一些额外上下文数据。 Trace
处理将值封装到 Tracer
中,并处理原语的应用。
class Trace:
  """Base class for interpreters.

  A Trace holds no data beyond a reference to its MainTrace; several Trace
  instances may be created and discarded for one MainTrace. Subclasses must
  override pure, lift and process_primitive. They raise NotImplementedError
  rather than using `assert False`, which `python -O` strips, so that a missing
  override cannot silently return None.
  """
  main: MainTrace

  def __init__(self, main: MainTrace) -> None:
    self.main = main

  def pure(self, val):
    """Box a non-Tracer constant into this trace's Tracer type."""
    raise NotImplementedError  # must override

  def lift(self, val):
    """Box a Tracer from a lower-level trace into this trace's Tracer type."""
    raise NotImplementedError  # must override

  def process_primitive(self, primitive, tracers, params):
    """Apply this interpreter's rule for primitive; return a list of outputs."""
    raise NotImplementedError  # must override
前两个方法是关于将值封装到 Tracer
中,这些 Tracer
是流经我们转换的 Python 程序的对象。 最后一个方法是我们用来解释原语应用的回调。
Trace
本身不包含任何数据,除了对其对应的 MainTrace
实例的引用。 实际上,在应用转换的过程中,可能会创建和丢弃 Trace
的多个实例,而每次应用转换只会创建一个 MainTrace
实例。
至于 Tracer
本身,每个都携带一个抽象值(并将中缀运算符转发给它),其余取决于转换。 (Tracer
和 AbstractValue
之间的关系是,每个转换都有一个 Tracer
,每个基本类型(例如数组)至少有一个 AbstractValue
。)
import numpy as np
class Tracer:
  """A boxed value flowing through a transformed program.

  Infix operators and attribute lookups are forwarded to the abstract value
  (aval), which subclasses must provide.
  """
  _trace: Trace
  # Makes NumPy binary operations defer to Tracer's reflected methods
  # (e.g. ndarray * Tracer -> Tracer.__rmul__).
  __array_priority__ = 1000

  @property
  def aval(self):
    # NotImplementedError instead of `assert False`, which is stripped by -O.
    raise NotImplementedError  # must override

  def full_lower(self):
    return self  # default implementation

  def __neg__(self): return self.aval._neg(self)
  def __add__(self, other): return self.aval._add(self, other)
  def __radd__(self, other): return self.aval._radd(self, other)
  def __mul__(self, other): return self.aval._mul(self, other)
  def __rmul__(self, other): return self.aval._rmul(self, other)
  def __gt__(self, other): return self.aval._gt(self, other)
  def __lt__(self, other): return self.aval._lt(self, other)
  def __bool__(self): return self.aval._bool(self)
  def __nonzero__(self): return self.aval._nonzero(self)

  def __getattr__(self, name):
    """Forward unknown attributes (e.g. shape, dtype, ndim) to the aval."""
    if name == 'aval':
      # aval itself is missing (e.g. an unset __slots__ entry); forwarding
      # would call self.aval again and recurse forever.
      raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")
    try:
      return getattr(self.aval, name)
    except AttributeError:
      raise AttributeError(
          f"{self.__class__.__name__} has no attribute {name}") from None
def swap(f):
  """Return a binary function that calls f with its two arguments reversed."""
  def swapped(x, y):
    return f(y, x)
  return swapped


class ShapedArray:
  """Abstract value: the set of all arrays with a given shape and dtype."""
  array_abstraction_level = 1
  shape: tuple[int, ...]
  dtype: np.dtype

  def __init__(self, shape, dtype):
    self.shape, self.dtype = shape, dtype

  @property
  def ndim(self):
    return len(self.shape)

  # Operator implementations that Tracer's dunder methods dispatch to.
  _neg = staticmethod(neg)
  _add = staticmethod(add)
  _radd = staticmethod(swap(add))
  _mul = staticmethod(mul)
  _rmul = staticmethod(swap(mul))
  _gt = staticmethod(greater)
  _lt = staticmethod(less)

  # Without a concrete value, truthiness is undefined.
  @staticmethod
  def _bool(tracer):
    raise Exception("ShapedArray can't be unambiguously converted to bool")

  @staticmethod
  def _nonzero(tracer):
    raise Exception("ShapedArray can't be unambiguously converted to bool")

  def str_short(self):
    dims = ",".join([str(d) for d in self.shape])
    return f'{self.dtype.name}[{dims}]'

  def __hash__(self):
    return hash((self.shape, self.dtype))

  def __eq__(self, other):
    if type(self) is not type(other):
      return False
    return self.shape == other.shape and self.dtype == other.dtype

  def __repr__(self):
    return f"ShapedArray(shape={self.shape}, dtype={self.dtype})"
class ConcreteArray(ShapedArray):
  """Abstract value for one known array: a singleton set."""
  array_abstraction_level = 2
  val: np.ndarray

  def __init__(self, val):
    self.val = val
    self.shape, self.dtype = val.shape, val.dtype

  # With a concrete value available, truthiness defers to that value.
  @staticmethod
  def _bool(tracer):
    return bool(tracer.aval.val)

  _nonzero = _bool
def get_aval(x):
  """Return x's abstract value.

  A Tracer's own aval is returned; a value whose exact type is in jax_types
  becomes a ConcreteArray. Anything else raises TypeError(x).
  """
  if isinstance(x, Tracer):
    return x.aval
  if type(x) not in jax_types:
    raise TypeError(x)
  return ConcreteArray(np.asarray(x))


# Exact Python/NumPy types accepted as array values.
jax_types = {
    bool, int, float,
    np.bool_, np.int32, np.int64, np.float32, np.float64,
    np.ndarray,
}
请注意,我们实际上为数组提供了两个 AbstractValue
,表示不同的抽象级别。ShapedArray
表示具有给定形状和 dtype 的所有可能数组的集合。ConcreteArray
表示由单个数组值组成的单例集合。
现在我们已经设置了解释器堆栈、用于解释器的 Trace/Tracer API 和抽象值,可以回过头来实现 bind 了:
def bind(prim, *args, **params):
  """Apply prim to args using the highest-level relevant interpreter."""
  trace = find_top_trace(args)
  # Box every positional argument into that trace's Tracer type.
  boxed = [full_raise(trace, a) for a in args]
  results = trace.process_primitive(prim, boxed, params)
  # Optional optimization: unbox outputs where possible.
  return [full_lower(r) for r in results]
主要操作是我们调用 find_top_trace
来确定哪个解释器应该处理此次原语应用。 然后,我们调用该顶级跟踪的 process_primitive
,以便该跟踪可以应用其解释规则。 对 full_raise
的调用只是确保输入封装在顶级跟踪的 Tracer
实例中,而对 full_lower
的调用是一个可选的优化,以便我们尽可能地将值从 Tracer
中解封。
import operator as op
def find_top_trace(xs) -> Trace:
  """Choose the Trace that should interpret a primitive applied to xs.

  Uses the highest-level MainTrace among the Tracer inputs, or trace_stack[0]
  (the bottom of the stack) when no input is a Tracer. A dynamic trace whose
  level is higher than that takes precedence. Returns a fresh Trace instance
  for the chosen MainTrace.
  """
  top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
                 default=trace_stack[0], key=op.attrgetter('level'))
  # Omnistaging: a dynamic trace above every input's level wins.
  if dynamic_trace and dynamic_trace.level > top_main.level:
    top_main = dynamic_trace
  return top_main.trace_type(top_main)
换句话说,忽略第 3 部分之前的 dynamic_trace
步骤,find_top_trace
返回与其输入上的 Tracer
关联的最高级别解释器,否则返回堆栈底部的解释器(至少目前始终是求值跟踪)。 这与上述描述有所不同,在上述描述中,我们始终从运行堆栈顶部的解释器开始,然后向下工作,应用堆栈中的每个解释器。 相反,我们仅在原始绑定的输入参数封装在与该解释器对应的 Tracer
中时才应用解释器。 此优化允许我们跳过不相关的转换,但会内置一个假设,即转换主要遵循数据依赖关系(堆栈底部的特殊解释器除外,它会解释所有内容)。
另一种方法是让堆栈中的每个解释器解释每个操作。 这值得探索! JAX 的设计在很大程度上围绕数据依赖关系,因为这对于自动微分来说非常自然,并且 JAX 的根源在于自动微分。 但它可能过度拟合。
def full_lower(val: Any):
  """Unbox val via Tracer.full_lower if it is a Tracer; return other values as is."""
  if isinstance(val, Tracer):
    return val.full_lower()
  else:
    return val

def full_raise(trace: Trace, val: Any) -> Tracer:
  """Box val into a Tracer belonging to trace.

  Non-Tracer constants go through trace.pure. Tracers from a lower-level trace
  go through trace.lift. Tracers already belonging to trace are returned as is.

  Raises:
    Exception: if val belongs to a higher-level trace, or to a different trace
      at the same level.
  """
  if not isinstance(val, Tracer):
    assert type(val) in jax_types
    return trace.pure(val)
  level = trace.main.level
  if val._trace.main is trace.main:
    return val
  elif val._trace.main.level < level:
    return trace.lift(val)
  elif val._trace.main.level > level:
    raise Exception(f"Can't lift level {val._trace.main.level} to {level}.")
  else:  # val._trace.main.level == level
    raise Exception(f"Different traces at same level: {val._trace}, {trace}.")
full_raise
中的逻辑用于将值封装到特定 Trace
的 Tracer
中,根据上下文在 Trace
上调用不同的方法:对非 Tracer
常量调用 Trace.pure
,对已经是来自较低级别解释器的 Tracer
的值调用 Trace.lift
。 这两种方法可以共享相同的实现,但是通过在核心逻辑中区分它们,我们可以向 Trace
子类提供更多信息。
这就是 JAX 核心的全部内容! 现在我们可以开始添加解释器了。
求值解释器#
我们将从最简单的解释器开始:位于解释器堆栈底部的求值解释器。
class EvalTrace(Trace):
  """Bottom-of-stack interpreter: applies impl rules directly to values."""

  def pure(self, x):
    return x  # no boxing in Tracers needed

  lift = pure

  def process_primitive(self, primitive, tracers, params):
    impl = impl_rules[primitive]
    return impl(*tracers, **params)


# Install the evaluation interpreter as the special bottom of the stack.
trace_stack.append(MainTrace(0, EvalTrace, None))
# NB: in JAX, instead of a dict we attach impl rules to the Primitive instance
impl_rules = {}
# Each rule evaluates with NumPy and returns a list of outputs.
impl_rules.update({
    add_p: lambda x, y: [np.add(x, y)],
    mul_p: lambda x, y: [np.multiply(x, y)],
    neg_p: lambda x: [np.negative(x)],
    sin_p: lambda x: [np.sin(x)],
    cos_p: lambda x: [np.cos(x)],
    reduce_sum_p: lambda x, *, axis: [np.sum(x, axis)],
    greater_p: lambda x, y: [np.greater(x, y)],
    less_p: lambda x, y: [np.less(x, y)],
    transpose_p: lambda x, *, perm: [np.transpose(x, perm)],
})


def broadcast_impl(x, *, shape, axes):
  """Insert size-1 dims at `axes` (ascending), then broadcast to shape."""
  expanded = x
  for new_axis in sorted(axes):
    expanded = np.expand_dims(expanded, new_axis)
  return [np.broadcast_to(expanded, shape)]


impl_rules[broadcast_p] = broadcast_impl
使用此解释器,我们可以求值用户函数
def f(x):
  # Same example function; with only the evaluation interpreter on the stack,
  # every primitive is applied directly to concrete values.
  y = sin(x) * 2.
  z = - y + x
  return z
print(f(3.0))
2.7177599838802657
哇! 就像绕了一个大圈子。 但是这种间接的目的在于,现在我们可以添加一些真正的转换。
使用 jvp
的前向模式自动微分#
首先,一些辅助函数
import builtins
def zeros_like(val):
  """Return a NumPy array of zeros with val's shape and dtype."""
  aval = get_aval(val)
  return np.zeros(aval.shape, aval.dtype)


def unzip2(pairs):
  """Split an iterable of pairs into two lists: ([firsts], [seconds])."""
  lst1, lst2 = [], []
  for x1, x2 in pairs:
    lst1.append(x1)
    lst2.append(x2)
  return lst1, lst2


def map(f, *xs):
  """Eager builtins.map: always returns a list."""
  return list(builtins.map(f, *xs))


def zip(*args):
  """Eager, length-checked builtins.zip: returns a list of tuples.

  Mismatched argument lengths are an error instead of being silently
  truncated. With no arguments it returns [], like builtins.zip.
  """
  if not args:
    return []
  fst, *rest = args = map(list, args)
  n = len(fst)
  for arg in rest:
    assert len(arg) == n, f"zip: argument lengths differ ({len(arg)} != {n})"
  return list(builtins.zip(*args))
前向模式自动微分的 Tracer
携带一个原始值-切线值对。Trace
应用 JVP 规则。
class JVPTracer(Tracer):
  """Forward-mode autodiff tracer carrying a (primal, tangent) pair."""

  def __init__(self, trace, primal, tangent):
    self._trace, self.primal, self.tangent = trace, primal, tangent

  @property
  def aval(self):
    return get_aval(self.primal)


class JVPTrace(Trace):
  """Interpreter that applies JVP rules to primal/tangent pairs."""

  def pure(self, val):
    # A constant is boxed with a zero tangent.
    return JVPTracer(self, val, zeros_like(val))

  lift = pure

  def process_primitive(self, primitive, tracers, params):
    primals_in = [t.primal for t in tracers]
    tangents_in = [t.tangent for t in tracers]
    rule = jvp_rules[primitive]
    primal_outs, tangent_outs = rule(primals_in, tangents_in, **params)
    return [JVPTracer(self, p, t) for p, t in zip(primal_outs, tangent_outs)]


jvp_rules = {}
请注意,pure
和 lift
都将值打包到具有最少上下文的 JVPTracer
中,这个上下文是一个零切线值。
让我们为原始操作添加一些 JVP 规则
def add_jvp(primals, tangents):
  """d(x + y) = dx + dy."""
  x, y = primals
  x_dot, y_dot = tangents
  return [x + y], [x_dot + y_dot]
jvp_rules[add_p] = add_jvp


def mul_jvp(primals, tangents):
  """Product rule: d(x * y) = dx * y + x * dy."""
  x, y = primals
  x_dot, y_dot = tangents
  return [x * y], [x_dot * y + x * y_dot]
jvp_rules[mul_p] = mul_jvp


def sin_jvp(primals, tangents):
  """d sin(x) = cos(x) * dx."""
  (x,) = primals
  (x_dot,) = tangents
  return [sin(x)], [cos(x) * x_dot]
jvp_rules[sin_p] = sin_jvp


def cos_jvp(primals, tangents):
  """d cos(x) = -sin(x) * dx."""
  (x,) = primals
  (x_dot,) = tangents
  return [cos(x)], [-sin(x) * x_dot]
jvp_rules[cos_p] = cos_jvp


def neg_jvp(primals, tangents):
  """d(-x) = -dx."""
  (x,) = primals
  (x_dot,) = tangents
  return [neg(x)], [neg(x_dot)]
jvp_rules[neg_p] = neg_jvp


def reduce_sum_jvp(primals, tangents, *, axis):
  """Summation is linear: the tangent is reduced over the same axes."""
  (x,) = primals
  (x_dot,) = tangents
  return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)]
jvp_rules[reduce_sum_p] = reduce_sum_jvp


def _comparison_jvp(compare, primals, tangents):
  """Shared rule for comparisons: boolean outputs get a zero tangent."""
  x, y = primals
  out_primal = compare(x, y)
  return [out_primal], [zeros_like(out_primal)]


def greater_jvp(primals, tangents):
  return _comparison_jvp(greater, primals, tangents)
jvp_rules[greater_p] = greater_jvp


def less_jvp(primals, tangents):
  return _comparison_jvp(less, primals, tangents)
jvp_rules[less_p] = less_jvp
最后,我们添加一个转换 API 来启动跟踪
def jvp_v1(f, primals, tangents):
  """Forward-mode JVP of f (array positional inputs, single array output).

  Returns (primal_out, tangent_out).
  """
  with new_main(JVPTrace) as main:
    trace = JVPTrace(main)
    tracers_in = [JVPTracer(trace, p, t) for p, t in zip(primals, tangents)]
    tracer_out = full_raise(trace, f(*tracers_in))
    result = tracer_out.primal, tracer_out.tangent
  return result
有了这些,我们就可以进行微分了!
# Differentiate sin at 3.0 with a unit tangent; the result should equal cos(3.0).
x = 3.0
y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,))
print(sin_deriv_at_3)
print(cos(3.0))
-0.9899924966004454
-0.9899924966004454
# JVP of a composite function: prints the primal output, then the tangent output.
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z
x, xdot = 3., 1.
y, ydot = jvp_v1(f, (x,), (xdot,))
print(y)
print(ydot)
2.7177599838802657
2.979984993200891
def deriv(f):
  """Return the derivative of scalar function f, via jvp_v1 with a unit tangent."""
  return lambda x: jvp_v1(f, (x,), (1.,))[1]
# Nesting deriv nests JVP interpreters, giving higher-order derivatives.
print(deriv(sin)(3.))
print(deriv(deriv(sin))(3.))
print(deriv(deriv(deriv(sin)))(3.))
print(deriv(deriv(deriv(deriv(sin))))(3.))
-0.9899924966004454
-0.1411200080598672
0.9899924966004454
0.1411200080598672
def f(x):
  # Python control flow works here: x is a JVPTracer with a concrete primal, so
  # the comparison's aval is a ConcreteArray and can be converted to bool.
  if x > 0.:  # Python control flow
    return 2. * x
  else:
    return x
print(deriv(f)(3.))
print(deriv(f)(-3.))
2.0
1.0
Pytrees 和展平用户函数的输入和输出#
jvp_v1
的一个限制是它假设用户函数接受数组作为位置参数,并生成一个数组作为输出。如果它产生一个列表作为输出怎么办?或者接受嵌套的容器作为输入怎么办?在堆栈的每一层处理输入和输出中所有可能的容器将是一件很痛苦的事情。相反,我们可以包装用户函数,使包装后的版本接受数组作为输入,并返回一个扁平的数组列表作为输出。包装器只需要将其输入还原(unflatten)为原来的结构,调用用户函数,再展平输出。
假设用户始终为我们提供以数组作为输入并生成一个扁平的数组列表作为输出的函数,以下是我们希望如何编写 jvp
的方式
def jvp_flat(f, primals, tangents):
  """Forward-mode JVP of f, which takes arrays and returns a flat list of arrays.

  Returns (list of primal outputs, list of tangent outputs).
  """
  with new_main(JVPTrace) as main:
    trace = JVPTrace(main)
    outs = f(*[JVPTracer(trace, p, t) for p, t in zip(primals, tangents)])
    tracers_out = [full_raise(trace, o) for o in outs]
    primals_out = [t.primal for t in tracers_out]
    tangents_out = [t.tangent for t in tracers_out]
  return primals_out, tangents_out
为了支持在输入和输出中具有任意容器的用户函数,以下是我们如何编写面向用户的 jvp
包装器
def jvp(f, primals, tangents):
  """Forward-mode JVP for f taking and returning arbitrary pytrees of arrays.

  Returns (primals_out, tangents_out), both with f's output pytree structure.

  Raises:
    TypeError: if primals and tangents do not have the same pytree structure.
  """
  primals_flat, in_tree = tree_flatten(primals)
  tangents_flat, in_tree2 = tree_flatten(tangents)
  if in_tree != in_tree2:
    raise TypeError(
        "jvp: primals and tangents must have the same pytree structure")
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, tangents_out_flat = jvp_flat(f, primals_flat, tangents_flat)
  # out_tree() is filled in as a side effect of running the flattened f.
  tree = out_tree()
  primals_out = tree_unflatten(tree, primals_out_flat)
  tangents_out = tree_unflatten(tree, tangents_out_flat)
  return primals_out, tangents_out
请注意,我们必须将用户函数输出的树结构传递回 flatten_fun
的调用者。在实际运行用户函数之前,该信息不可用,因此 flatten_fun
仅返回对可变单元格的引用,表示为 thunk。这些副作用是安全的,因为我们始终只运行用户函数一次。(这种安全机制是 linear_util.py
中“linear”名称的原因,在 线性类型 的意义上。)
剩下要做的就是编写 tree_flatten
、tree_unflatten
和 flatten_fun
。
显示代码单元格源代码
def flatten_fun(f, in_tree):
  """Wrap f so that it takes and returns flat lists of leaves.

  Returns (flat_fun, store). Once flat_fun has been called, store() gives the
  PyTreeDef of f's output.
  """
  out_tree_store = Store()

  def flat_fun(*args_flat):
    out_flat, out_tree = tree_flatten(f(*tree_unflatten(in_tree, args_flat)))
    out_tree_store.set_value(out_tree)
    return out_flat

  return flat_fun, out_tree_store


class Empty:
  """Type of the sentinel marking a Store with no value yet."""


empty = Empty()


class Store:
  """Write-once cell; calling it returns the stored value (or `empty`)."""
  val = empty

  def set_value(self, val):
    assert self.val is empty  # may only be set once
    self.val = val

  def __call__(self):
    return self.val
显示代码单元格源代码
from collections.abc import Hashable, Iterable, Iterator
import itertools as it
from collections.abc import Callable
class NodeType(NamedTuple):
  """How to take apart and rebuild one registered container type."""
  name: str
  to_iterable: Callable    # container -> (metadata, children)
  from_iterable: Callable  # (metadata, children) -> container

def register_pytree_node(ty: type, to_iter: Callable, from_iter: Callable
                         ) -> None:
  """Register exact type ty as a pytree container with the given conversions."""
  node_types[ty] = NodeType(str(ty), to_iter, from_iter)

node_types: dict[type, NodeType] = {}
register_pytree_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs))
register_pytree_node(list, lambda l: (None, l), lambda _, xs: list(xs))
# Dicts flatten in sorted-key order; the tuple of sorted keys is the metadata.
register_pytree_node(dict,
                     lambda d: map(tuple, unzip2(sorted(d.items()))),
                     lambda keys, vals: dict(zip(keys, vals)))

class PyTreeDef(NamedTuple):
  """Structure of a pytree: node type, its metadata, and child structures."""
  node_type: NodeType
  node_metadata: Hashable
  child_treedefs: tuple['PyTreeDef', ...]

class Leaf: pass
leaf = Leaf()  # singleton structure used for every leaf value

def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]:
  """Flatten x into (list of leaves, structure)."""
  children_iter, treedef = _tree_flatten(x)
  return list(children_iter), treedef

def _tree_flatten(x: Any) -> tuple[Iterable, PyTreeDef]:
  # Dispatch on the exact type, so values of unregistered types (including
  # subclasses of registered ones) are leaves.
  node_type = node_types.get(type(x))
  if node_type:
    node_metadata, children = node_type.to_iterable(x)
    children_flat, child_trees = unzip2(map(_tree_flatten, children))
    flattened = it.chain.from_iterable(children_flat)
    return flattened, PyTreeDef(node_type, node_metadata, tuple(child_trees))
  else:
    return [x], leaf

def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any:
  """Rebuild a pytree with structure treedef from the flat leaves xs."""
  return _tree_unflatten(treedef, iter(xs))

def _tree_unflatten(treedef: PyTreeDef, xs: Iterator) -> Any:
  # Leaves are consumed from the shared iterator xs in flattening order.
  if treedef is leaf:
    return next(xs)
  else:
    children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs)
    return treedef.node_type.from_iterable(treedef.node_metadata, children)
有了这个处理 pytree 的 jvp
实现,我们现在可以处理任意的输入和输出容器。这在未来的转换中也会派上用场!
# f returns a dict containing a list; jvp flattens it for jvp_flat and rebuilds
# the same structure for both the primal and the tangent outputs.
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return {'hi': z, 'there': [x, y]}
x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
{'hi': np.float64(2.7177599838802657), 'there': [3.0, np.float64(0.2822400161197344)]}
{'hi': np.float64(2.979984993200891), 'there': [1.0, np.float64(-1.9799849932008908)]}
使用 vmap
的矢量化批处理#
首先,是几个辅助函数,一个用于从非映射值生成映射的抽象值(通过删除一个轴),另一个用于移动批处理维度
def mapped_aval(batch_dim, aval):
  """Abstract value of one slice of aval: its shape with batch_dim removed."""
  dims = list(aval.shape)
  dims.pop(batch_dim)
  return ShapedArray(tuple(dims), aval.dtype)


def move_batch_axis(axis_size, src, dst, x):
  """Return x with its batch axis at position dst.

  An unbatched x (src is not_mapped) is broadcast to gain a new axis of size
  axis_size at dst.
  """
  if src is not_mapped:
    new_shape = list(np.shape(x))
    new_shape.insert(dst, axis_size)
    return broadcast(x, new_shape, [dst])
  if src == dst:
    return x
  return moveaxis(x, src, dst)


def moveaxis(x, src: int, dst: int):
  """Move axis src of x to position dst using a transpose."""
  perm = [ax for ax in range(np.ndim(x)) if ax != src]
  perm.insert(dst, src)
  return transpose(x, perm)
用于矢量化批处理的 Tracer
携带一个批处理值和一个可选的整数,指示哪个轴(如果有)是批处理轴。
from typing import Union
class NotMapped:
  """Type of the not_mapped sentinel: a value that has no batch axis."""


not_mapped = NotMapped()

BatchAxis = Union[NotMapped, int]


class BatchTracer(Tracer):
  """vmap tracer: a batched value plus the index of its batch axis."""

  def __init__(self, trace, val, batch_dim: BatchAxis):
    self._trace = trace
    self.val = val
    self.batch_dim = batch_dim

  @property
  def aval(self):
    val_aval = get_aval(self.val)
    if self.batch_dim is not_mapped:
      return val_aval
    return mapped_aval(self.batch_dim, val_aval)

  def full_lower(self):
    # A value with no batch axis needs no BatchTracer wrapper.
    if self.batch_dim is not_mapped:
      return full_lower(self.val)
    return self


class BatchTrace(Trace):
  """Interpreter that applies vmap (batching) rules."""

  def pure(self, val):
    return BatchTracer(self, val, not_mapped)

  lift = pure

  def process_primitive(self, primitive, tracers, params):
    vals_in = [t.val for t in tracers]
    bdims_in = [t.batch_dim for t in tracers]
    rule = vmap_rules[primitive]
    val_outs, bdim_outs = rule(self.axis_size, vals_in, bdims_in, **params)
    return [BatchTracer(self, v, bd) for v, bd in zip(val_outs, bdim_outs)]

  @property
  def axis_size(self):
    # vmap_flat stores the batch size as the MainTrace's global data.
    return self.main.global_data


vmap_rules = {}
这里我们实现了可选的 Tracer.full_lower
方法:当批处理跟踪器并不表示批处理值(即其 batch_dim 为 not_mapped)、因而不再需要时,可以用该方法将它剥离。
对于 BatchTrace
,类似于 JVPTrace
,方法 pure
和 lift
只是将一个值封装在具有最小上下文的 BatchTracer
中,在这种情况下,最小上下文是取哨兵值 not_mapped
的 batch_dim
。请注意,我们使用 MainTrace
的解释器全局数据字段来存储批处理轴大小。
接下来,我们可以为每个原始操作定义批处理解释器规则
from functools import partial
def binop_batching_rule(op, axis_size, vals_in, dims_in):
  """Batching rule for elementwise binary ops.

  Align both operands' batch axes (broadcasting an unbatched operand if
  needed), then apply op once.
  """
  (x, y), (x_bdim, y_bdim) = vals_in, dims_in
  if x_bdim != y_bdim:
    if x_bdim is not_mapped:
      x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
      x_bdim = y_bdim
    else:
      y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
  return [op(x, y)], [x_bdim]
vmap_rules[add_p] = partial(binop_batching_rule, add)
vmap_rules[mul_p] = partial(binop_batching_rule, mul)


def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):
  """Elementwise unary ops apply to the batched value; the batch axis is unchanged."""
  (x,), (x_bdim,) = vals_in, dims_in
  return [op(x)], [x_bdim]
vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)
vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)
vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)


def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):
  """Batching rule for reduce_sum.

  axis indexes the unbatched operand. Axes at or after the batch axis are
  shifted by one to index the batched value, and the output batch axis moves
  left by the number of reduced axes before it. An unbatched operand
  (not_mapped) is reduced as-is, since it cannot be compared with integer axes.
  """
  (x,), (x_bdim,) = vals_in, dims_in
  if x_bdim is not_mapped:
    return [reduce_sum(x, axis)], [not_mapped]
  new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)
  out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)
  return [reduce_sum(x, new_axis)], [out_bdim]
vmap_rules[reduce_sum_p] = reduce_sum_batching_rule
最后,我们添加一个转换 API 来启动跟踪
def vmap_flat(f, in_axes, *args):
  """Vectorize f (flat list of arrays in, flat list out) over the given axes.

  Each entry of in_axes is an integer batch-axis index or `not_mapped`. Every
  output is returned with its batch axis moved to position 0.

  Raises:
    ValueError: if no argument is mapped, or mapped axes differ in size.
  """
  axis_sizes = {x.shape[ax] for x, ax in zip(args, in_axes)
                if ax is not not_mapped}
  if len(axis_sizes) != 1:
    raise ValueError(
        "vmap requires at least one mapped input and all mapped axes to have "
        f"the same size; got sizes {sorted(axis_sizes)}")
  axis_size, = axis_sizes
  with new_main(BatchTrace, axis_size) as main:
    trace = BatchTrace(main)
    # Unmapped inputs are passed through unboxed (the test must use the same
    # not_mapped sentinel as above); BatchTrace.pure boxes them with
    # batch_dim=not_mapped if they meet a batched value.
    tracers_in = [BatchTracer(trace, x, ax) if ax is not not_mapped else x
                  for x, ax in zip(args, in_axes)]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    vals_out, bdims_out = unzip2((t.val, t.batch_dim) for t in tracers_out)
  outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out)
                     for val_out, bdim in zip(vals_out, bdims_out)]
  return outs_transposed


def vmap(f, in_axes):
  """Return a batched version of f, which may take and return pytrees.

  in_axes must have the same pytree structure as the tuple of positional args.
  """
  def batched_f(*args):
    args_flat, in_tree = tree_flatten(args)
    in_axes_flat, in_tree2 = tree_flatten(in_axes)
    if in_tree != in_tree2:
      raise TypeError(
          "vmap: in_axes must have the same pytree structure as the arguments")
    f_flat, out_tree = flatten_fun(f, in_tree)
    outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat)
    return tree_unflatten(out_tree(), outs_flat)
  return batched_f
def add_one_to_a_scalar(scalar):
  # Written for one scalar; under vmap the BatchTracer's aval has the batch
  # axis removed, so the ndim check still passes.
  assert np.ndim(scalar) == 0
  return 1 + scalar
vector_in = np.arange(3.)
vector_out = vmap(add_one_to_a_scalar, (0,))(vector_in)
print(vector_in)
print(vector_out)
[0. 1. 2.]
[1. 2. 3.]
def jacfwd(f, x):
  """Forward-mode Jacobian of f at x: vmap a JVP over standard basis vectors."""
  pushfwd = lambda v: jvp(f, (x,), (v,))[1]
  # NOTE(review): reshape(np.shape(x) * 2) leaves only x.shape[0] basis vectors
  # on axis 0, so this appears to assume 1-D x. Row i is the JVP along e_i
  # (i.e. column i of the Jacobian), so the result looks transposed relative
  # to the usual convention. Confirm intent.
  vecs_in = np.eye(np.size(x)).reshape(np.shape(x) * 2)
  return vmap(pushfwd, (0,))(vecs_in)
def f(x):
  return sin(x)
jacfwd(f, np.arange(3.))
array([[ 1. , 0. , -0. ],
[ 0. , 0.54030231, -0. ],
[ 0. , 0. , -0.41614684]])
这就是 jvp
和 vmap
的全部内容!
第 2 部分:Jaxprs#
接下来要进行的转换是用于即时编译的 jit
和用于反向模式自动微分的 vjp
。(grad
只是 vjp
的一个小包装器。)虽然 jvp
和 vmap
只需要每个 Tracer
携带一点额外的上下文,但对于 jit
和 vjp
,我们需要更丰富的上下文:我们需要表示程序。也就是说,我们需要 jaxpr!
Jaxpr 是 JAX 的内部程序中间表示。它们是显式类型化的、函数式的、一阶的和 ANF 形式的。我们需要一个用于 jit
的程序表示,因为 jit
的目的是将计算移出 Python。对于我们想要移出的任何计算,我们需要能够将其表示为数据,并在我们跟踪 Python 函数时构建它。类似地,vjp
需要一种方法来表示反向模式自动微分的反向传递的计算。我们使用相同的 jaxpr 程序表示来满足这两个需求。
(构建程序表示是最自由的跟踪转换,因此,除了处理原生 Python 控制流的问题之外,任何转换都可以通过首先跟踪到 jaxpr 然后解释 jaxpr 来实现。)
Jaxpr 数据结构#
jaxpr 术语语法大致如下
jaxpr ::=
{ lambda <binder> , ... .
let <eqn>
...
in ( <atom> , ... ) }
binder ::= <var>:<array_type>
var ::= a | b | c | ...
atom ::= <var> | <literal>
literal ::= <int32> | <int64> | <float32> | <float64>
eqn ::= <binder> , ... = <primitive> [ <params> ] <atom> , ...
类型的语法是
jaxpr_type ::= [ <array_type> , ... ] -> [ <array_type> , ... ]
array_type ::= <dtype>[<shape>]
dtype ::= f32 | f64 | i32 | i64
shape ::= <int> , ...
我们如何将这些表示为 Python 数据结构?我们重用 ShapedArrays 来表示类型,并且可以使用一些 Python 结构来表示术语语法
class Var:
  """A jaxpr variable, typed by its aval and hashed by identity."""
  aval: ShapedArray

  def __init__(self, aval):
    self.aval = aval


class Lit:
  """A jaxpr literal: a constant stored as a NumPy array of its dtype."""
  val: Any
  aval: ShapedArray

  def __init__(self, val):
    aval = raise_to_shaped(get_aval(val))
    self.aval = aval
    self.val = np.array(val, aval.dtype)


Atom = Union[Var, Lit]


class JaxprEqn(NamedTuple):
  """One equation: out_binders = primitive[params](*inputs)."""
  primitive: Primitive
  inputs: list[Atom]
  params: dict[str, Any]
  out_binders: list[Var]


class Jaxpr(NamedTuple):
  """A program: input binders, a list of equations, and output atoms."""
  in_binders: list[Var]
  eqns: list[JaxprEqn]
  outs: list[Atom]

  # Identity-based hashing and equality, so a Jaxpr can be a dict/cache key
  # even though its fields are lists.
  def __hash__(self):
    return id(self)

  __eq__ = op.is_


def raise_to_shaped(aval):
  """Drop any concrete value: a ShapedArray with aval's shape and dtype."""
  return ShapedArray(aval.shape, aval.dtype)
类型检查 jaxpr 包括检查是否有未绑定的变量、变量是否只绑定一次,以及对于每个方程,原语应用的结果类型是否与输出绑定器的类型匹配。
class JaxprType(NamedTuple):
  """Function type of a jaxpr: its input and output ShapedArrays."""
  in_types: list[ShapedArray]
  out_types: list[ShapedArray]

  def __repr__(self):
    in_types = ', '.join(aval.str_short() for aval in self.in_types)
    out_types = ', '.join(aval.str_short() for aval in self.out_types)
    return f'({in_types}) -> ({out_types})'


def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType:
  """Check that jaxpr is well-formed and well-typed, and return its type.

  Every variable must be bound exactly once before use, and each equation's
  out binders must match, in number and type, what abstract_eval_rules gives.

  Raises:
    TypeError: describing the first violation found.
  """
  env: set[Var] = set()

  for v in jaxpr.in_binders:
    if v in env:
      raise TypeError("typecheck_jaxpr: input variable bound more than once")
    env.add(v)

  for eqn in jaxpr.eqns:
    in_types = [typecheck_atom(env, x) for x in eqn.inputs]
    out_types = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params)
    if len(out_types) != len(eqn.out_binders):
      raise TypeError(
          f"typecheck_jaxpr: {eqn.primitive.name} produces {len(out_types)} "
          f"outputs but the equation has {len(eqn.out_binders)} out binders")
    for out_binder, out_type in zip(eqn.out_binders, out_types):
      if out_type != out_binder.aval:
        raise TypeError(
            f"typecheck_jaxpr: {eqn.primitive.name} output type {out_type} "
            f"does not match binder type {out_binder.aval}")
    for out_binder in eqn.out_binders:
      if out_binder in env:
        raise TypeError("typecheck_jaxpr: variable bound more than once")
      env.add(out_binder)

  in_types = [v.aval for v in jaxpr.in_binders]
  out_types = [typecheck_atom(env, x) for x in jaxpr.outs]
  return JaxprType(in_types, out_types)


def typecheck_atom(env: set[Var], x: Atom) -> ShapedArray:
  """Return the type of atom x, checking that a Var is bound in env.

  Raises:
    TypeError: for an unbound Var or an object that is neither Var nor Lit.
  """
  if isinstance(x, Var):
    if x not in env: raise TypeError("unbound variable")
    return x.aval
  elif isinstance(x, Lit):
    return raise_to_shaped(get_aval(x.val))
  else:
    raise TypeError(
        f"typecheck_atom: expected a Var or Lit, got {type(x).__name__}")
我们可以使用一个简单的解释器将 jaxpr 表示的函数应用于参数。
def eval_jaxpr(jaxpr: Jaxpr, args: list[Any]) -> list[Any]:
  """Evaluate jaxpr on args, binding each equation's primitive in order.

  Primitives go through bind, so this interpreter is itself traceable.

  Raises:
    TypeError: if the number of args differs from len(jaxpr.in_binders).
  """
  args = list(args)
  if len(args) != len(jaxpr.in_binders):
    raise TypeError(
        f"eval_jaxpr: expected {len(jaxpr.in_binders)} arguments, "
        f"got {len(args)}")
  env: dict[Var, Any] = {}

  def read(x: Atom) -> Any:
    return env[x] if type(x) is Var else x.val

  def write(v: Var, val: Any) -> None:
    assert v not in env  # single-assignment
    env[v] = val

  # Explicit loops: the writes are side effects, so don't rely on map() being
  # eager or on it silently truncating mismatched lengths.
  for v, val in zip(jaxpr.in_binders, args):
    write(v, val)
  for eqn in jaxpr.eqns:
    in_vals = [read(x) for x in eqn.inputs]
    outs = bind(eqn.primitive, *in_vals, **eqn.params)
    for v, val in zip(eqn.out_binders, outs):
      write(v, val)
  return [read(x) for x in jaxpr.outs]


def jaxpr_as_fun(jaxpr: Jaxpr):
  """Wrap jaxpr as a Python function of positional arguments."""
  return lambda *args: eval_jaxpr(jaxpr, args)
通过在解释器中使用 bind
,该解释器本身是可跟踪的。
使用跟踪构建 jaxpr#
现在我们有了作为数据结构的 jaxpr,我们需要一些方法来从跟踪 Python 代码生成这些 jaxpr。通常,我们跟踪到 jaxpr 的方式有两种变体;jit
使用一种,而 vjp
使用另一种。我们将从 jit
使用的那一种开始,这种方法也用于控制流原始操作,如 lax.cond
、lax.while_loop
和 lax.scan
。
def split_list(lst: list[Any], n: int) -> tuple[list[Any], list[Any]]:
  """Split lst into its first n elements and the remainder."""
  assert 0 <= n <= len(lst)
  head, tail = lst[:n], lst[n:]
  return head, tail


def partition_list(bs: list[bool], l: list[Any]) -> tuple[list[Any], list[Any]]:
  """Partition l by the flags bs: (items flagged False, items flagged True)."""
  assert len(bs) == len(l)
  buckets: tuple[list[Any], list[Any]] = ([], [])
  for flag, item in zip(bs, l):
    buckets[flag].append(item)  # a bool flag indexes bucket 0 or 1
  return buckets[0], buckets[1]
# NB: the analogous class in JAX is called 'DynamicJaxprTracer'
class JaxprTracer(Tracer):
  """Tracer carrying only an abstract value; its jaxpr Var is kept by the builder."""
  __slots__ = ['aval']
  aval: ShapedArray
  def __init__(self, trace, aval):
    self._trace = trace
    self.aval = aval

# NB: the analogous class in JAX is called 'DynamicJaxprTrace'
class JaxprTrace(Trace):
  """Interpreter that records primitive applications as jaxpr equations."""
  def new_arg(self, aval: ShapedArray) -> JaxprTracer:
    """Create a tracer, and bind a fresh Var for it, for one jaxpr input."""
    aval = raise_to_shaped(aval)
    tracer = self.builder.new_tracer(self, aval)
    self.builder.tracer_to_var[id(tracer)] = Var(aval)
    return tracer
  def get_or_make_const_tracer(self, val: Any) -> JaxprTracer:
    """Return the tracer for constant val (keyed by id(val)), creating it once."""
    tracer = self.builder.const_tracers.get(id(val))
    if tracer is None:
      tracer = self.builder.new_tracer(self, raise_to_shaped(get_aval(val)))
      self.builder.add_const(tracer, val)
    return tracer
  # Both non-Tracer constants and lower-level Tracers become jaxpr constants.
  pure = lift = get_or_make_const_tracer
  def process_primitive(self, primitive, tracers, params):
    """Record one equation and return tracers for its outputs."""
    avals_in = [t.aval for t in tracers]
    # Output types come from the primitive's abstract evaluation rule.
    avals_out = abstract_eval_rules[primitive](*avals_in, **params)
    out_tracers = [self.builder.new_tracer(self, a) for a in avals_out]
    inputs = [self.builder.getvar(t) for t in tracers]
    outvars = [self.builder.add_var(t) for t in out_tracers]
    self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvars))
    return out_tracers
  @property
  def builder(self):
    # The JaxprBuilder is the MainTrace's global data (see make_jaxpr).
    return self.main.global_data

# NB: in JAX, we instead attach abstract eval rules to Primitive instances
abstract_eval_rules = {}
请注意,我们将一个构建器对象作为解释器的全局数据保存,该对象在我们构建 jaxpr 时跟踪变量、常量和方程。
class JaxprBuilder:
  """Accumulates equations, variables and constants while tracing to a jaxpr."""
  eqns: list[JaxprEqn]
  tracer_to_var: dict[int, Var]          # id(tracer) -> its Var
  const_tracers: dict[int, JaxprTracer]  # id(constant value) -> its tracer
  constvals: dict[Var, Any]              # Var bound to a constant -> the value
  tracers: list[JaxprTracer]             # every tracer made by new_tracer
  def __init__(self):
    self.eqns = []
    self.tracer_to_var = {}
    self.const_tracers = {}
    self.constvals = {}
    self.tracers = []
  def new_tracer(self, trace: JaxprTrace, aval: ShapedArray) -> JaxprTracer:
    """Create a JaxprTracer with the given aval and record it."""
    # NOTE(review): holding every tracer presumably keeps it alive so that the
    # id() keys in tracer_to_var cannot be reused while tracing. Confirm.
    tracer = JaxprTracer(trace, aval)
    self.tracers.append(tracer)
    return tracer
  def add_eqn(self, eqn: JaxprEqn) -> None:
    self.eqns.append(eqn)
  def add_var(self, tracer: JaxprTracer) -> Var:
    """Bind a fresh Var for tracer; a tracer may be bound only once."""
    assert id(tracer) not in self.tracer_to_var
    var = self.tracer_to_var[id(tracer)] = Var(tracer.aval)
    return var
  def getvar(self, tracer: JaxprTracer) -> Var:
    """Return the Var previously bound for tracer."""
    var = self.tracer_to_var.get(id(tracer))
    assert var is not None
    return var
  def add_const(self, tracer: JaxprTracer, val: Any) -> Var:
    """Bind a Var for tracer and record val as that Var's constant value."""
    var = self.add_var(tracer)
    self.const_tracers[id(val)] = tracer
    self.constvals[var] = val
    return var
  def build(self, in_tracers: list[JaxprTracer], out_tracers: list[JaxprTracer]
            ) -> tuple[Jaxpr, list[Any]]:
    """Assemble and typecheck the jaxpr.

    Constant Vars become the leading input binders; scalar constants are then
    inlined as literals. Returns (jaxpr, remaining constant values).
    """
    constvars, constvals = unzip2(self.constvals.items())
    t2v = lambda t: self.tracer_to_var[id(t)]
    in_binders = constvars + [t2v(t) for t in in_tracers]
    out_vars = [t2v(t) for t in out_tracers]
    jaxpr = Jaxpr(in_binders, self.eqns, out_vars)
    typecheck_jaxpr(jaxpr)
    jaxpr, constvals = _inline_literals(jaxpr, constvals)
    return jaxpr, constvals

def _inline_literals(jaxpr: Jaxpr, consts: list[Any]) -> tuple[Jaxpr, list[Any]]:
  """Replace scalar constant binders with Lit atoms; keep other constants.

  Expects the first len(consts) in_binders to be the constant binders.
  Returns the rewritten (and re-typechecked) jaxpr and the remaining consts.
  """
  const_binders, other_binders = split_list(jaxpr.in_binders, len(consts))
  # Inline a constant if its exact type is in jax_types and its shape is ().
  scalars = [type(x) in jax_types and not get_aval(x).shape for x in consts]
  new_const_binders, lit_binders = partition_list(scalars, const_binders)
  new_consts, lit_vals = partition_list(scalars, consts)
  literals = dict(zip(lit_binders, map(Lit, lit_vals)))
  new_eqns = [JaxprEqn(eqn.primitive, [literals.get(x, x) for x in eqn.inputs],
                       eqn.params, eqn.out_binders) for eqn in jaxpr.eqns]
  new_outs = [literals.get(x, x) for x in jaxpr.outs]
  new_jaxpr = Jaxpr(new_const_binders + other_binders, new_eqns, new_outs)
  typecheck_jaxpr(new_jaxpr)
  return new_jaxpr, new_consts
我们需要用于 JaxprTrace.process_primitive
的规则本质上是原语应用的类型规则:给定原语、其参数(params)和输入的类型,规则必须为输出生成一个类型,然后将其与输出 JaxprTracer
一起打包。我们可以为此目的使用抽象评估规则,即使它们可能更通用(因为抽象评估规则必须接受 ConcreteArray 输入,并且由于它们只需要返回可能的输出集的上限,因此它们也可以生成 ConcreteArray 输出)。我们将重用这些抽象评估规则用于其他生成 jaxpr 的跟踪机制,其中潜在的额外通用性很有用。
def binop_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
  """Output type of an elementwise binop; operands must match in shape and dtype."""
  if not (isinstance(x, ShapedArray) and isinstance(y, ShapedArray)):
    raise TypeError
  if raise_to_shaped(x) != raise_to_shaped(y):
    raise TypeError
  return [ShapedArray(x.shape, x.dtype)]


abstract_eval_rules.update({
    add_p: binop_abstract_eval,
    mul_p: binop_abstract_eval,
})


def compare_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
  """Output type of a comparison: same shape as the operands, bool dtype."""
  if not (isinstance(x, ShapedArray) and isinstance(y, ShapedArray)):
    raise TypeError
  if x.shape != y.shape:
    raise TypeError
  return [ShapedArray(x.shape, np.dtype('bool'))]


abstract_eval_rules.update({
    greater_p: compare_abstract_eval,
    less_p: compare_abstract_eval,
})


def vectorized_unop_abstract_eval(x: ShapedArray) -> list[ShapedArray]:
  """Elementwise unary ops preserve shape and dtype."""
  return [ShapedArray(x.shape, x.dtype)]


abstract_eval_rules.update({
    sin_p: vectorized_unop_abstract_eval,
    cos_p: vectorized_unop_abstract_eval,
    neg_p: vectorized_unop_abstract_eval,
})


def reduce_sum_abstract_eval(x: ShapedArray, *, axis: tuple[int, ...]
                             ) -> list[ShapedArray]:
  """Drop the reduced axes from x's shape; the dtype is unchanged."""
  reduced = set(axis)
  kept = tuple(d for i, d in enumerate(x.shape) if i not in reduced)
  return [ShapedArray(kept, x.dtype)]


abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval


def broadcast_abstract_eval(x: ShapedArray, *, shape: Sequence[int],
                            axes: Sequence[int]) -> list[ShapedArray]:
  """The output has the target shape and x's dtype."""
  return [ShapedArray(tuple(shape), x.dtype)]


abstract_eval_rules[broadcast_p] = broadcast_abstract_eval
为了检查我们的 jaxpr 实现,我们可以添加一个 make_jaxpr
转换和一个漂亮打印机
from functools import lru_cache
@lru_cache  # ShapedArrays are hashable
def make_jaxpr_v1(f, *avals_in):
  """Trace f on abstract inputs avals_in into (jaxpr, consts, out_tree).

  Results are memoized by functools.lru_cache on (f, *avals_in). Only
  primitives that depend on the JaxprTracer inputs are staged out; see
  make_jaxpr for the omnistaging version.
  """
  avals_in, in_tree = tree_flatten(avals_in)
  f, out_tree = flatten_fun(f, in_tree)
  builder = JaxprBuilder()
  with new_main(JaxprTrace, builder) as main:
    trace = JaxprTrace(main)
    tracers_in = [trace.new_arg(aval) for aval in avals_in]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    jaxpr, consts = builder.build(tracers_in, tracers_out)
  return jaxpr, consts, out_tree()
显示代码单元格源代码
from collections import defaultdict
import string
class PPrint:
lines: list[tuple[int, str]]
def __init__(self, lines):
self.lines = lines
def indent(self, indent: int) -> 'PPrint':
return PPrint([(indent + orig_indent, s) for orig_indent, s in self.lines])
def __add__(self, rhs: 'PPrint') -> 'PPrint':
return PPrint(self.lines + rhs.lines)
def __rshift__(self, rhs: 'PPrint') -> 'PPrint':
if not rhs.lines: return self
if not self.lines: return rhs
indent, s = self.lines[-1]
indented_block = rhs.indent(indent + len(s))
common_line = s + ' ' * rhs.lines[0][0] + rhs.lines[0][1]
return PPrint(self.lines[:-1]
+ [(indent, common_line)]
+ indented_block.lines[1:])
def __str__(self) -> str:
return '\n'.join(' ' * indent + s for indent, s in self.lines)
def pp(s: Any) -> PPrint:
return PPrint([(0, line) for line in str(s).splitlines()])
def vcat(ps: list[PPrint]) -> PPrint:
return sum(ps, pp(''))
def pp_jaxpr(jaxpr: Jaxpr) -> PPrint:
namegen = (''.join(s) for r in it.count(1)
for s in it.permutations(string.ascii_lowercase, r))
names = defaultdict(lambda: next(namegen))
in_binders = ', '.join(var_str(names, x) for x in jaxpr.in_binders)
eqns = vcat([pp_eqn(names, e) for e in jaxpr.eqns])
outs = ', '.join(names[v] if isinstance(v, Var) else str(v.val)
for v in jaxpr.outs)
return (pp(f'{{ lambda {in_binders} .') +
((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2))
def var_str(names: defaultdict[Var, str], v: Var) -> str:
return f'{names[v]}:{v.aval.str_short()}'
def pp_eqn(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
    """Render one equation as ``outs = prim [params] inputs``.

    A primitive-specific printer registered in `pp_rules` takes precedence.
    """
    custom_rule = pp_rules.get(eqn.primitive)
    if custom_rule:
        return custom_rule(names, eqn)
    binders = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
    operands = ' '.join(names[x] if isinstance(x, Var) else str(x.val)
                        for x in eqn.inputs)
    rhs = pp(eqn.primitive.name) >> pp_params(eqn.params) >> pp(operands)
    return binders >> pp(' = ') >> rhs
def pp_params(params: dict[str, Any]) -> PPrint:
    """Render params as `` [ k=v ... ] `` sorted by key, or a single space if none."""
    if not params:
        return pp(' ')
    entries = vcat([pp(f'{key}={val}') for key, val in sorted(params.items())])
    return pp(' [ ') >> entries >> pp(' ] ')
# Print Jaxprs with the pretty-printer above. `pp_rules` lets individual
# primitives override how their equations are printed (consulted by pp_eqn).
Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))
pp_rules: dict[Primitive, Callable[..., PPrint]] = {}
# Trace `2. * x` for a scalar float input, then print the jaxpr and its type.
jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. * x, raise_to_shaped(get_aval(3.)))
print(jaxpr)
print(typecheck_jaxpr(jaxpr))
{ lambda a:float64[] .
let b:float64[] = mul 2.0 a
in ( b ) }
(float64[]) -> (float64[])
但这里有一个限制:由于 find_top_trace
通过数据依赖性进行操作的方式,make_jaxpr_v1
无法将给定 Python 可调用对象执行的所有原语操作都分阶段输出。例如
# No traced input flows into `mul` here, so make_jaxpr_v1 evaluates it eagerly
# and only the constant result 4.0 appears in the printed jaxpr.
jaxpr, consts, _ = make_jaxpr_v1(lambda: mul(2., 2.))
print(jaxpr)
{ lambda .
let
in ( 4.0 ) }
这正是 omnistaging 修复的问题。我们希望确保由 make_jaxpr
启动的 JaxprTrace
始终被应用,无论 bind
的任何输入是否都框在相应的 JaxprTracer
实例中。我们可以通过使用第 1 部分中定义的 dynamic_trace
全局变量来实现这一点
@contextmanager
def new_dynamic(main: MainTrace):
    """Temporarily install `main` as the global `dynamic_trace`.

    The previous value is restored on exit, even if the body raises.
    """
    global dynamic_trace
    saved = dynamic_trace
    dynamic_trace = main
    try:
        yield
    finally:
        dynamic_trace = saved
@lru_cache
def make_jaxpr(f: Callable, *avals_in: ShapedArray,
               ) -> tuple[Jaxpr, list[Any], PyTreeDef]:
    """Trace `f` on abstract inputs into ``(jaxpr, consts, out_tree)``.

    Unlike `make_jaxpr_v1`, the JaxprTrace's main is also installed as the
    dynamic trace. Every primitive `f` binds is therefore staged out, even
    those whose operands do not depend on the inputs. Memoized with
    `lru_cache`.
    """
    flat_avals, in_tree = tree_flatten(avals_in)
    flat_f, out_tree = flatten_fun(f, in_tree)
    builder = JaxprBuilder()
    with new_main(JaxprTrace, builder) as main, new_dynamic(main):
        trace = JaxprTrace(main)
        arg_tracers = [trace.new_arg(aval) for aval in flat_avals]
        raised_outs = [full_raise(trace, out) for out in flat_f(*arg_tracers)]
        jaxpr, consts = builder.build(arg_tracers, raised_outs)
    return jaxpr, consts, out_tree()
# With the dynamic trace installed, the constant multiply is now staged out.
jaxpr, consts, _ = make_jaxpr(lambda: mul(2., 2.))
print(jaxpr)
{ lambda .
let a:float64[] = mul 2.0 2.0
in ( a ) }
以这种方式使用 dynamic_trace
在概念上与隐藏当前解释器堆栈并启动一个新的解释器堆栈(JaxprTrace
位于底部)相同。也就是说,不应用堆栈中低于 dynamic_trace
的解释器(因为 JaxprTrace.process_primitive
不调用 bind
),但如果被跟踪到 jaxpr 的 Python 可调用对象本身使用转换,则这些转换可以被推送到 JaxprTrace
之上的解释器堆栈中。但是,临时隐藏解释器堆栈会破坏系统状态。dynamic_trace
标记在保持系统状态更简单的同时实现了相同的目标。
这就是 jaxpr 的全部内容!有了 jaxpr,我们可以实现其余的主要 JAX 功能。
第 3 部分:jit
,简化#
虽然 jit
有一个类似于转换的 API,因为它接受一个 Python 可调用对象作为参数,但实际上它在底层是一个高阶原语,而不是一个转换。当一个原语由一个函数参数化时,它就是高阶的。
即时(“最终样式”)和分阶段(“初始样式”)处理#
对于如何处理高阶原语,有两种选择。每种方法都需要不同的追踪方法,并会产生不同的权衡。
即时处理,其中
bind
接受一个 Python 可调用对象作为参数。 我们会尽可能晚地推迟形成 jaxpr,即直到我们在解释器堆栈的底部运行最终解释器时才进行。这样,我们可以在解释器堆栈的底部交换一个JaxprTrace
,从而将所有原语操作分阶段输出,而不是执行它们。使用这种方法,当我们像往常一样执行 Python 可调用对象时,堆栈中的转换将被应用。这种方法实现起来可能非常棘手,但它尽可能通用,因为它允许高阶原语不提高其参数的抽象级别,从而允许数据相关的 Python 控制流。我们将此方法称为使用“最终样式的高阶原语”,它沿用了我们目前一直使用的、在追踪时即被消解(discharge)的“最终样式转换”。分阶段处理,其中
bind
接受一个 jaxpr 作为参数。 在我们调用bind
之前,在原语包装器中,我们可以只使用make_jaxpr
来预先形成一个 jaxpr,并完全处理掉 Python 可调用对象。在这种情况下,make_jaxpr
将其JaxprTrace
放在解释器堆栈的顶部,而堆栈中较低的任何转换(它们可能通过闭包捕获的 Tracer 进入)都不会在我们追踪时应用于 Python 可调用对象。(在 Python 可调用对象内应用的转换会像往常一样被应用,并被添加到 JaxprTrace 上方的堆栈中。)相反,堆栈中较低的转换会稍后应用于调用原语,而调用原语的规则随后必须转换 jaxpr 本身。因为我们预先追踪到一个 jaxpr,所以这种方法不支持数据相关的 Python 控制流,但它更容易实现。我们将这种高阶原语称为“初始样式高阶原语”,并说它处理 jaxpr 的转换规则是“初始样式转换规则”。
后一种方法适用于 jit
,因为我们不需要在用户提供的 Python 可调用对象中支持数据相关的 Python 控制流,因为 jit
的全部目的是将计算从 Python 中分阶段输出,以便由 XLA 执行。(相比之下,custom_jvp
是一个高阶原语,我们希望在其中支持数据相关的 Python 控制流。)
从历史上看,我们在阅读了关于类型化无标签最终解释器(typed tagless-final interpreters)的论文之后,开始使用“初始样式”和“最终样式”这两个术语,并开玩笑地将 JAX 称为“无类型有标签最终解释器”的一种实现。我们并不声称理解这些术语背后的任何深刻含义;我们粗略地用“初始样式”表示“先构建 AST 再转换它”,用“最终样式”表示“在追踪时转换”。这只是一种不精确但朗朗上口的说法。
使用初始样式方法,这是用户面向的 jit
包装器
def jit(f):
    """Return a version of `f` that is staged out to a jaxpr and run via XLA.

    Each call traces `f` on the shaped avals of its arguments (the tracing is
    memoized by `make_jaxpr`) and binds `xla_call_p` with the jaxpr's
    constants followed by the arguments.
    NOTE(review): positional args are not pytree-flattened here; each must be
    an array-like leaf.
    """
    def f_jitted(*args):
        in_avals = [raise_to_shaped(get_aval(arg)) for arg in args]
        jaxpr, consts, out_tree = make_jaxpr(f, *in_avals)
        flat_outs = bind(xla_call_p, *consts, *args,
                         jaxpr=jaxpr, num_consts=len(consts))
        return tree_unflatten(out_tree, flat_outs)
    return f_jitted


xla_call_p = Primitive('xla_call')
对于任何新的原语,我们需要为其提供转换规则,首先是其求值规则。当我们评估 xla_call
原语的应用时,我们希望将计算分阶段输出到 XLA。这涉及到将 jaxpr 翻译成 XLA HLO 程序,将参数值传输到 XLA 设备,执行 XLA 程序,然后将结果传输回来。我们将缓存 XLA HLO 编译,以便对于每个 jit
化的函数,每个参数形状和 dtype 签名只需要执行一次。
首先,一些实用程序。
class IDHashable:
    """Wrap a value so it hashes and compares by object identity.

    Lets objects such as jaxprs and constant arrays serve as `lru_cache`
    keys by identity rather than by value.
    """
    val: Any

    def __init__(self, val):
        self.val = val

    def __eq__(self, other):
        if type(other) is not IDHashable:
            return False
        return other.val is self.val

    def __hash__(self) -> int:
        return id(self.val)
接下来,我们将定义 xla_call
的求值规则
import io
from jax.extend.mlir import ir
from jax.extend.mlir.dialects import func
from jax.extend.mlir.dialects import stablehlo as hlo
from jax._src import xla_bridge as xb
# Bundles the MLIR module being built with its symbol table, so translation
# rules (e.g. for nested xla_call) can add new functions to the module.
MlirContext = NamedTuple('MlirContext', [('module', ir.Module),
                                         ('symbol_table', ir.SymbolTable)])
def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
    """Evaluation rule for `xla_call_p`: compile `jaxpr` (cached) and run it.

    The first `num_consts` positional args are the jaxpr's constants. They are
    wrapped by identity so they can form part of the compilation cache key.
    """
    const_args, real_args = args[:num_consts], args[num_consts:]
    execute = xla_callable(IDHashable(jaxpr),
                           tuple(IDHashable(c) for c in const_args))
    return execute(*real_args)


impl_rules[xla_call_p] = xla_call_impl
@lru_cache
def xla_callable(hashable_jaxpr: IDHashable,
                 hashable_consts: tuple[IDHashable, ...]):
    """Compile `jaxpr` with XLA and return a function that executes it.

    The jaxpr's leading inputs are bound to the constants, which are embedded
    in the program as HLO constants. The returned callable takes only the
    remaining arguments. Results are cached by `lru_cache`, keyed on the
    identities of the jaxpr and the constants (via `IDHashable`).
    """
    jaxpr: Jaxpr = hashable_jaxpr.val
    typecheck_jaxpr(jaxpr)  # Validate the jaxpr; the returned type is unused.
    consts = [x.val for x in hashable_consts]
    # Only the non-constant inputs become parameters of the MLIR `main`.
    in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]

    with ir.Context() as ctx, ir.Location.unknown(ctx):
        hlo.register_dialect(ctx)
        m = ir.Module.create()
        c = MlirContext(m, ir.SymbolTable(m.operation))

        with ir.InsertionPoint(c.module.body):
            @func.func(*(aval_to_ir_type(aval) for aval in in_avals))
            def main(*params):
                # Constants come first, matching the jaxpr's binder order.
                return jaxpr_subcomp(c, jaxpr, _hlo_consts(consts) + params)

    # Serialize the module to text and compile it on the default backend.
    output = io.StringIO()
    c.module.operation.print(file=output)
    compiled = xb.get_backend(None).compile(output.getvalue())

    return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])
def _mlir_dtype(dtype: np.dtype) -> ir.Type:
if np.issubdtype(dtype, np.signedinteger):
return ir.IntegerType.get_signless(np.iinfo(dtype).bits)
elif dtype == np.float32:
return ir.F32Type.get()
elif dtype == np.float64:
return ir.F64Type.get()
else:
raise NotImplementedError("MLIR conversion not implemented for ", dtype)
def aval_to_ir_type(aval: ShapedArray) -> ir.Type:
    """Return the ranked MLIR tensor type matching `aval`'s shape and dtype."""
    element_type = _mlir_dtype(aval.dtype)
    return ir.RankedTensorType.get(aval.shape, element_type)
def _hlo_const(x: Any) -> ir.Value:
    """Emit a StableHLO constant holding the value of `x` and return it.

    Presumably requires an active MLIR context and insertion point; every
    caller here runs inside one.
    """
    a = np.asarray(x)
    if a.dtype == np.bool_:
        # Boolean arrays are handed over bit-packed (little-endian bit order),
        # with an explicit i1 element type and the original shape.
        return hlo.constant(ir.DenseElementsAttr.get(
            np.packbits(a, bitorder='little'), type=ir.IntegerType.get_signless(1),
            shape=a.shape))
    else:
        return hlo.constant(ir.DenseElementsAttr.get(a))
def _hlo_consts(consts: list[Any]) -> tuple[ir.Value, ...]:
    """Emit one HLO constant per distinct object in `consts`.

    Objects that appear several times (compared by identity) share a single
    constant op. Returns a tuple aligned with `consts`, so callers can
    concatenate it with the tuple of function parameters.
    """
    unique_consts = {id(cnst): cnst for cnst in consts}
    ir_consts = {id_: _hlo_const(cnst) for id_, cnst in unique_consts.items()}
    return tuple(ir_consts[id(cnst)] for cnst in consts)
主要操作在 xla_callable
中,它使用 jaxpr_subcomp
将 jaxpr 编译成 XLA HLO 程序,然后返回一个可执行编译程序的函数
def jaxpr_subcomp(c: MlirContext, jaxpr: Jaxpr, args: list[ir.Value]) -> list[ir.Value]:
    """Emit MLIR ops for `jaxpr` at the current insertion point.

    A small interpreter: `args` are bound to the jaxpr's input binders, then
    each equation is lowered with its `hlo_translations` rule.

    Returns:
      The MLIR values of the jaxpr's outputs.

    Raises:
      NotImplementedError: if an equation's primitive has no translation rule.
    """
    env: dict[Var, ir.Value] = {}

    def read(x: Atom) -> ir.Value:
        # Vars are looked up; literals become fresh HLO constants.
        return env[x] if type(x) is Var else _hlo_const(np.asarray(x.val))

    def write(v: Var, val: ir.Value) -> None:
        env[v] = val

    map(write, jaxpr.in_binders, args)
    for eqn in jaxpr.eqns:
        in_avals = [x.aval for x in eqn.inputs]
        in_vals = map(read, eqn.inputs)
        out_avals = [x.aval for x in eqn.out_binders]
        rule = hlo_translations.get(eqn.primitive)
        if rule is None:
            raise NotImplementedError(
                f"no HLO translation rule for primitive {eqn.primitive.name}")
        assert all(isinstance(v, ir.Value) for v in in_vals), in_vals
        out_vals = rule(c, in_avals, out_avals, in_vals, **eqn.params)
        assert all(isinstance(v, ir.Value) for v in out_vals), out_vals
        map(write, eqn.out_binders, out_vals)
    return map(read, jaxpr.outs)
def execute_compiled(compiled, out_avals, *args):
    """Run `compiled` on `args` and wrap each output with `handle_result`.

    Each argument is converted to a device buffer by the handler registered
    for its exact type in `input_handlers`.
    """
    buffers = []
    for arg in args:
        to_buffer = input_handlers[type(arg)]
        buffers.append(to_buffer(arg))
    results = compiled.execute(buffers)
    return [handle_result(aval, buf) for aval, buf in zip(out_avals, results)]
# Host values become device buffers via the default backend's
# `buffer_from_pyval`. Handlers are keyed by exact type, so any type not
# listed here (or registered later) raises KeyError in `execute_compiled`.
default_input_handler = xb.get_backend(None).buffer_from_pyval
input_handlers = {ty: default_input_handler for ty in
                  [bool, int, float, np.ndarray, np.float64, np.float32]}
def handle_result(aval: ShapedArray, buf):
    """Transfer an output buffer back to host memory as a NumPy array.

    `aval` is accepted to match the handler interface but is not used yet.
    """
    del aval  # Unused for now
    return np.asarray(buf)


# Maps each Primitive to its rule for lowering an equation to StableHLO.
hlo_translations = dict()
请注意,jaxpr_subcomp
具有简单解释器的结构。这是一种常见的模式:我们处理 jaxpr 的方式通常是使用解释器。与任何解释器一样,我们需要为每个原语提供解释规则
def direct_translation(op, c, in_avals, out_avals, in_vals):
    """Lower a primitive that maps one-to-one onto the StableHLO builder `op`."""
    del c, in_avals, out_avals  # Only the operand values are needed.
    result = op(*in_vals)
    return [result]


hlo_translations.update({
    prim: partial(direct_translation, hlo_op)
    for prim, hlo_op in [(add_p, hlo.add), (mul_p, hlo.multiply),
                         (neg_p, hlo.negate), (sin_p, hlo.sine),
                         (cos_p, hlo.cosine)]
})
def compare_translation(op, c, in_avals, out_avals, in_vals):
    """Lower a comparison primitive to `hlo.compare` with direction `op`.

    `op` is a StableHLO comparison-direction name such as "GT" or "LT".
    """
    del c, out_avals
    direction = hlo.ComparisonDirectionAttr.get(op)
    return [hlo.compare(*in_vals, direction)]


hlo_translations.update({greater_p: partial(compare_translation, "GT"),
                         less_p: partial(compare_translation, "LT")})
def reduce_sum_translation(c, in_avals, out_avals, in_vals, *, axis):
    """Lower `reduce_sum` to a StableHLO reduce op whose body adds scalars.

    The reduction starts from a zero constant of the input dtype and reduces
    over the dimensions listed in `axis`.
    """
    del c
    (x_aval,), (out_aval,), (x,) = in_avals, out_avals, in_vals
    op = hlo.ReduceOp(
        [aval_to_ir_type(out_aval)], [x], [_hlo_const(np.array(0, x_aval.dtype))],
        axis)
    # The reducer region takes two scalars of the element type and returns
    # their sum.
    scalar_type = aval_to_ir_type(ShapedArray((), x_aval.dtype))
    reducer_region = op.body.blocks.append(scalar_type, scalar_type)
    with ir.InsertionPoint(reducer_region):
        hlo.return_([hlo.add(*reducer_region.arguments)])
    return op.results


hlo_translations[reduce_sum_p] = reduce_sum_translation
def broadcast_translation(c, in_avals, out_avals, in_vals, *, shape, axes):
    """Lower `broadcast` to `hlo.broadcast_in_dim`.

    `axes` are the newly inserted output dimensions. The operand's own
    dimensions map, in order, onto the remaining output dimensions.
    """
    del c
    (x,) = in_vals
    (out_aval,) = out_avals
    operand_dims = [d for d in range(len(shape)) if d not in axes]
    out_type = aval_to_ir_type(out_aval)
    return [hlo.broadcast_in_dim(out_type, x, operand_dims)]


hlo_translations[broadcast_p] = broadcast_translation
有了这些,我们现在可以使用 jit
来分阶段输出、编译和执行带有 XLA 的程序!
# The print runs only while `f` is being traced. Later calls with the same
# argument types reuse the cached jaxpr and compiled executable.
@jit
def f(x, y):
    print('tracing!')
    return sin(x) * cos(y)
z = f(3., 4.)  # 'tracing!' prints the first time
print(z)
tracing!
-0.09224219304455371
z = f(4., 5.)  # 'tracing!' doesn't print, compilation cache hit!
print(z)
-0.21467624978306993
# jit also handles reduce_sum, which is lowered to a StableHLO reduce op.
@jit
def f(x):
    return reduce_sum(x, axis=0)
print(f(np.array([1., 2., 3.])))
6.0
def f(x):
    y = sin(x) * 2.
    z = - y + x
    return z

def deriv(f):
    """Return the derivative of scalar `f`, computed via jvp with a unit tangent."""
    return lambda x: jvp(f, (x,), (1.,))[1]

# Second derivative evaluated eagerly and under jit; both print the same value.
print( deriv(deriv(f))(3.))
print(jit(deriv(deriv(f)))(3.))
0.2822400161197344
0.2822400161197344
与其实现 jit
先追踪到 jaxpr,然后再将 jaxpr 降级为 XLA HLO,似乎我们可以跳过 jaxpr 步骤,并在追踪时直接降级到 HLO。也就是说,也许我们可以使用一个 Trace
和 Tracer
来实现 jit
,它们在每个原语绑定时增量地附加到 XLA HLO 图上。对于现在来说,这是正确的,但是当我们引入编译的 SPMD 计算时,这将是不可能的,因为在那里我们必须在编译程序之前知道所需的副本数。
我们还没有为 xla_call_p
定义除其求值规则之外的任何转换规则。也就是说,我们还不能执行 vmap
-of-jit
或 jvp
-of-jit
甚至 jit
-of-jit
。相反,jit
必须在“顶层”。让我们修复它!
def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
    """JVP rule for `xla_call_p`: make a jit call of the JVP of `jaxpr`.

    The JVP jaxpr takes all primals followed by all tangents and returns all
    primal outputs followed by all tangent outputs.
    """
    del num_consts  # Unused
    jaxpr_jvp, consts_jvp = jvp_jaxpr(jaxpr)
    results = bind(xla_call_p, *consts_jvp, *primals, *tangents,
                   jaxpr=jaxpr_jvp, num_consts=len(consts_jvp))
    half = len(results) // 2
    return results[:half], results[half:]


jvp_rules[xla_call_p] = xla_call_jvp_rule
@lru_cache
def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
    """Trace the JVP of `jaxpr` into a new jaxpr (memoized per jaxpr).

    The new jaxpr takes ``(*primals, *tangents)``, with tangents typed like the
    primals, and returns the flattened ``(primals_out, tangents_out)`` pair.
    """
    def jvp_traceable(*primals_and_tangents):
        num_primals = len(primals_and_tangents) // 2
        return jvp(jaxpr_as_fun(jaxpr),
                   primals_and_tangents[:num_primals],
                   primals_and_tangents[num_primals:])

    avals = [binder.aval for binder in jaxpr.in_binders]
    traced_jaxpr, traced_consts, _ = make_jaxpr(jvp_traceable, *avals, *avals)
    return traced_jaxpr, traced_consts
def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
    """Batching rule for `xla_call_p`: make a jit call of the vmap of `jaxpr`.

    Every output of the batched call carries its batch dimension at axis 0.
    """
    del num_consts  # Unused
    batched_jaxpr, batched_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
    outs = bind(xla_call_p, *batched_consts, *vals_in,
                jaxpr=batched_jaxpr, num_consts=len(batched_consts))
    out_dims = [0 for _ in outs]
    return outs, out_dims


vmap_rules[xla_call_p] = xla_call_vmap_rule
@lru_cache
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
               ) -> tuple[Jaxpr, list[Any]]:
    """Trace `vmap` of `jaxpr` into a new jaxpr (memoized on all arguments).

    Each input gets its binder's aval with `axis_size` inserted at its batch
    dimension (see `unmapped_aval`).
    """
    batched_fun = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
    batched_avals = [unmapped_aval(axis_size, bdim, binder.aval)
                     for binder, bdim in zip(jaxpr.in_binders, bdims_in)]
    traced_jaxpr, traced_consts, _ = make_jaxpr(batched_fun, *batched_avals)
    return traced_jaxpr, traced_consts
def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
                  ) -> ShapedArray:
    """Return the aval of the full batched array for the per-example `aval`.

    `axis_size` is inserted at position `batch_dim`. Inputs that are not
    mapped are returned unchanged.
    """
    if batch_dim is not_mapped:
        return aval
    dims = [*aval.shape[:batch_dim], axis_size, *aval.shape[batch_dim:]]
    return ShapedArray(tuple(dims), aval.dtype)
def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):
    """Abstract evaluation rule for `xla_call_p`.

    Checks that the argument types, constants included, match the jaxpr's
    input types exactly, then returns the jaxpr's output types.

    Raises:
      TypeError: if the argument count or any argument type differs.
    """
    del num_consts  # Unused
    jaxpr_type = typecheck_jaxpr(jaxpr)
    expected = jaxpr_type.in_types
    # Check arity explicitly so a count mismatch is reported as a TypeError.
    if len(expected) != len(in_types):
        raise TypeError(f"xla_call expected {len(expected)} arguments, "
                        f"got {len(in_types)}")
    for i, (want, got) in enumerate(zip(expected, in_types)):
        if not want == got:
            raise TypeError(f"xla_call argument {i} has type {got}, "
                            f"but the jaxpr expects {want}")
    return jaxpr_type.out_types


abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule
def xla_call_translation(c, in_avals, out_avals, in_vals, *, jaxpr, num_consts):
    """Lower a nested `xla_call` to a call of a separate MLIR function.

    The inner jaxpr is emitted as its own `func.func` in the module body and
    invoked with `func.CallOp`, instead of being inlined into the caller.
    """
    del num_consts, out_avals
    # Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead.
    with ir.InsertionPoint(c.module.body):
        @func.func(*(aval_to_ir_type(aval) for aval in in_avals))
        def inner_xla_call(*params):
            return jaxpr_subcomp(c, jaxpr, params)
        # Register the function in the module's symbol table (presumably
        # uniquifying its name). The returned name was previously bound to an
        # unused local; CallOp is given the op itself, so it is not needed.
        c.symbol_table.insert(inner_xla_call.func_op)
    return func.CallOp(inner_xla_call.func_op, in_vals).results


hlo_translations[xla_call_p] = xla_call_translation
# jvp of a jitted function goes through the xla_call JVP rule: `f` is traced
# once (printing 'tracing!') and its JVP is bound as a new xla_call.
@jit
def f(x):
    print('tracing!')
    y = sin(x) * 2.
    z = - y + x
    return z
x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
tracing!
2.7177599838802657
2.979984993200891
y, ydot = jvp(f, (x,), (xdot,))  # 'tracing!' not printed
# vmap of a jitted function goes through the xla_call batching rule.
ys = vmap(f, (0,))(np.arange(3.))
print(ys)
[ 0. -0.68294197 0.18140515]
缺少的一项是数组的设备内存持久性。也就是说,我们已经定义了 handle_result
将结果作为 NumPy 数组传输回 CPU 内存,但是通常最好避免传输结果,只是为了在下一个操作中将其传输回来。我们可以通过引入一个 Array
类来实现,该类可以包装 XLA 缓冲区,否则可以鸭子类型为 numpy.ndarray
def handle_result(aval: ShapedArray, buf):  # noqa: F811
    """Keep an output on device by wrapping its buffer in an `Array`."""
    wrapped = Array(aval, buf)
    return wrapped
class Array:
    """Wraps an XLA output buffer together with its abstract value.

    Keeps jit results on device instead of converting them to NumPy right
    away. It duck-types as ``numpy.ndarray`` for ``dtype``/``shape``/``ndim``,
    ``np.asarray`` conversion, and printing.
    """
    buf: Any
    aval: ShapedArray

    def __init__(self, aval, buf):
        self.aval = aval
        self.buf = buf

    dtype = property(lambda self: self.aval.dtype)
    shape = property(lambda self: self.aval.shape)
    ndim = property(lambda self: self.aval.ndim)

    def __array__(self, dtype=None, copy=None):
        # Follow NumPy's documented signature __array__(dtype=None, copy=None).
        # NumPy may request a dtype, and NumPy >= 2 passes `copy` and warns
        # when the method cannot accept it.
        arr = np.asarray(self.buf, dtype=dtype)
        # copy=True must return a fresh array; copy=False/None is best effort.
        return arr.copy() if copy else arr

    def __repr__(self): return repr(np.asarray(self.buf))
    def __str__(self): return str(np.asarray(self.buf))

    # Primitive wrappers exposed as static methods; presumably looked up by
    # the infix-operator overloads set up earlier in the file (confirm).
    _neg = staticmethod(neg)
    _add = staticmethod(add)
    _radd = staticmethod(add)
    _mul = staticmethod(mul)
    _rmul = staticmethod(mul)
    _gt = staticmethod(greater)
    _lt = staticmethod(less)


# When an Array is passed to a compiled computation, reuse its device buffer.
input_handlers[Array] = lambda x: x.buf
jax_types.add(Array)
# Same computation as before, now with results kept on device as `Array`s.
@jit
def f(x):
    y = sin(x) * 2.
    z = - y + x
    return z
x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
2.7177599838802657
2.979984993200891
显示代码单元格源代码
def pprint_xla_call(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
    """Print an xla_call equation with its inner jaxpr nested below it.

    The `jaxpr` param is left out of the bracketed params and printed instead
    as an indented block after the equation line.
    """
    binders = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
    shown_params = {key: val for key, val in eqn.params.items() if key != 'jaxpr'}
    operands = ' '.join(names[x] if isinstance(x, Var) else str(x.val)
                        for x in eqn.inputs)
    call_line = binders >> pp(' = ') >> (
        pp(eqn.primitive.name) >> pp_params(shown_params) >> pp(operands))
    inner = pp_jaxpr(eqn.params['jaxpr']).indent(2)
    return vcat([call_line, inner])


pp_rules[xla_call_p] = pprint_xla_call
第 4 部分:linearize
和 vjp
(和 grad
!)#
linearize
和 vjp
自动微分函数构建在 jvp
之上,但也涉及 jaxpr。这是因为两者都涉及阶段输出或延迟计算。
linearize
#
在 linearize
的情况下,我们希望分阶段输出 jvp
计算的线性部分。也就是说,用类似 Haskell 的类型签名来说,如果我们有 jvp : (a -> b) -> (a, T a) -> (b, T b)
,那么我们写 linearize : (a -> b) -> a -> (b, T a -o T b)
,使用 T a
来表示“a
的切线类型”,并使用“棒棒糖” -o
而不是箭头 ->
来表示线性函数。我们也根据 jvp
定义 linearize
的语义
y, f_lin = linearize(f, x)
y_dot = f_lin(x_dot)
对于 (y, y_dot)
给出相同的结果,如
y, y_dot = jvp(f, (x,), (x_dot,))
其中 f_lin
的应用不会重做任何线性化工作。我们将延迟的线性部分 f_lin : T a -o T b
表示为 jaxpr。
顺便说一句,既然我们有了线性箭头 -o
,我们可以为 jvp
提供一个稍微更有信息量的类型
jvp : (a -> b) -> (UnrestrictedUse a, T a) -o (UnrestrictedUse b, T b)
在这里,我们编写 UnrestrictedUse
只是为了表明我们有一个特殊的对,其中第一个元素可以以不受限制(非线性)的方式使用。与线性箭头结合使用,此表示法仅旨在表达函数 jvp f
以非线性的方式使用其第一个输入,但以线性的方式使用其第二个输入,从而产生相应的非线性输出(可以以非线性的方式使用)和线性输出配对。这种更精细的类型签名编码了 jvp f
中的数据依赖关系,这对于部分评估很有用。
要从 JVP 构建 f_lin
jaxpr,我们需要执行部分评估:我们在跟踪时评估所有原始值,但将切线计算分阶段输出到 jaxpr 中。这是我们构建 jaxpr 的第二种方式。但是,make_jaxpr
及其基础的 JaxprTrace
/JaxprTracer
解释器的目标是分阶段输出每一次原语绑定,而第二种方法仅分阶段输出那些对切线输入有数据依赖关系的原语绑定。
首先,一些实用程序
def split_half(lst: list[Any]) -> tuple[list[Any], list[Any]]:
    """Split an even-length sequence into its first and second halves."""
    half, odd = divmod(len(lst), 2)
    assert not odd
    return split_list(lst, half)
def merge_lists(which: list[bool], l1: list[Any], l2: list[Any]) -> list[Any]:
    """Interleave `l1` and `l2` according to the mask `which`.

    Position i takes the next element of `l2` if ``which[i]`` is true, and the
    next element of `l1` otherwise. Both lists must be used up exactly.
    """
    it1, it2 = iter(l1), iter(l2)
    merged = []
    for from_second in which:
        merged.append(next(it2) if from_second else next(it1))
    leftover1 = next(it1, None)
    leftover2 = next(it2, None)
    assert leftover1 is None and leftover2 is None
    return merged
接下来,我们将通过组合 jvp
和一个通用的部分求值变换来编写 linearize
(该部分求值变换将在下文中添加):
def linearize_flat(f, *primals_in):
    """Linearize flat `f` at `primals_in`.

    Runs the JVP of `f` under partial evaluation, with the primals known and
    the tangents unknown. Primal outputs are computed now; the tangent
    computation is staged into a jaxpr.

    Returns:
      (primals_out, f_lin), where ``f_lin(*tangents)`` evaluates that jaxpr.
    """
    known = [PartialVal.known(x) for x in primals_in]
    unknown = [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in]

    def f_jvp(*primals_tangents_in):
        primals, tangents = split_half(primals_tangents_in)
        primals_out, tangents_out = jvp(f, primals, tangents)
        return [*primals_out, *tangents_out]

    jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, known + unknown)
    primal_pvals, _ = split_half(pvals_out)
    assert all(pval.is_known for pval in primal_pvals)
    primals_out = [pval.const for pval in primal_pvals]

    def f_lin(*tangents):
        return eval_jaxpr(jaxpr, [*consts, *tangents])

    return primals_out, f_lin
def linearize(f, *primals_in):
    """Pytree-aware linearize: return ``(primals_out, f_lin)``.

    `f_lin` takes tangents with the same tree structure as `primals_in` and
    raises TypeError otherwise.
    """
    flat_primals, in_tree = tree_flatten(primals_in)
    flat_f, out_tree = flatten_fun(f, in_tree)
    flat_primals_out, flat_f_lin = linearize_flat(flat_f, *flat_primals)
    primals_out = tree_unflatten(out_tree(), flat_primals_out)

    def f_lin(*tangents_in):
        flat_tangents, tangent_tree = tree_flatten(tangents_in)
        if in_tree != tangent_tree:
            raise TypeError
        return tree_unflatten(out_tree(), flat_f_lin(*flat_tangents))

    return primals_out, f_lin
def vspace(aval: ShapedArray) -> ShapedArray:
    """Return the tangent-space aval for `aval` (currently its shaped form)."""
    # TODO handle integers?
    tangent_aval = raise_to_shaped(aval)
    return tangent_aval
现在我们转向通用的部分求值变换。目标是接受一个 Python 可调用对象和一系列输入,其中一些是已知的,另一些是未知的,并生成 (1) 可以从已知输入计算出的所有输出,以及 (2) 表示 Python 可调用对象计算中剩余部分的 jaxpr,该部分只能在剩余输入已知后执行。
这种转换很难用类型签名来概括。如果我们假设输入函数的类型签名是 (a1, a2) -> (b1, b2)
,其中 a1
和 a2
分别表示已知输入和未知输入,并且 b1
仅依赖于 a1
的数据,而 b2
依赖于 a2
的数据,那么我们可以这样写:
partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> exists r. (b1, r, (r, a2) -> b2)
换句话说,给定类型为 a1
的输入值,partial_eval
会生成类型为 b1
的输出,以及类型为存在量化类型 r
的“残余”值,这些值表示在第二阶段完成计算所需的中间值。它还会生成一个类型为 (r, a2) -> b2
的函数,该函数接受残余值和剩余输入,并生成剩余输出。
我们喜欢将部分求值视为将一个计算“解压缩”为两个计算。例如,考虑以下 jaxpr:
{ lambda a:float64[] .
let b:float64[] = sin a
c:float64[] = neg b
in ( c ) }
JVP 的 jaxpr 如下所示:
{ lambda a:float64[] b:float64[] .
let c:float64[] = sin a
d:float64[] = cos a
e:float64[] = mul d b
f:float64[] = neg c
g:float64[] = neg e
in ( f, g ) }
如果我们想象对这个 jaxpr 应用部分求值,其中第一个输入已知而第二个输入未知,那么我们最终会将 JVP jaxpr “解压缩”为原始 jaxpr 和切线 jaxpr:
{ lambda a:float64[] .
let c:float64[] = sin a
d:float64[] = cos a
f:float64[] = neg c
in ( f, d ) }
{ lambda d:float64[] b:float64[] .
let e:float64[] = mul d b
g:float64[] = neg e
in ( g ) }
第二个 jaxpr 表示我们想要从 linearize
获得的线性计算。
然而,与这个 jaxpr 示例不同,我们希望在评估输入 Python 可调用对象时执行对已知值的计算。也就是说,我们不是为整个函数 (a1, a2) -> (b1, b2)
创建一个 jaxpr,先将所有操作移出 Python,然后再确定哪些操作现在可以评估,哪些操作必须延迟,而是只为那些必须延迟的操作创建一个 jaxpr,因为它们依赖于未知输入。在自动微分的上下文中,这是最终使我们能够处理诸如 grad(lambda x: x**2 if x > 0 else 0.)
之类函数的功能。Python 控制流之所以有效,是因为部分求值将原始计算保留在 Python 中。因此,我们的 Trace
和 Tracer
子类必须在运行时确定哪些可以评估,哪些必须被移出到一个 jaxpr 中。
首先,我们从 PartialVal
类开始,它表示一个可以是已知或未知的值:
class PartialVal(NamedTuple):
    """A value that is either known (held in `const`) or unknown (only `aval`).

    Note: a known value of ``None`` cannot be represented, because
    ``const is None`` is what marks a value as unknown.
    """
    aval: ShapedArray
    const: Any | None

    @classmethod
    def known(cls, val: Any):
        """Wrap a concrete value, recording its aval via `get_aval`."""
        return PartialVal(get_aval(val), val)

    @classmethod
    def unknown(cls, aval: ShapedArray):
        """Make a placeholder for a not-yet-known value of type `aval`."""
        return PartialVal(aval, None)

    @property
    def is_known(self):
        return self.const is not None

    @property
    def is_unknown(self):
        return not self.is_known
部分求值将采用表示输入的 PartialVal
列表,并返回 PartialVal
输出列表以及表示延迟计算的 jaxpr。
def partial_eval_flat(f: Callable, pvals_in: list[PartialVal]
                      ) -> tuple[Jaxpr, list[PartialVal], list[Any]]:
    """Partially evaluate flat `f` on a mix of known and unknown inputs.

    Computations on known values run eagerly. Everything that depends on an
    unknown input is recorded and converted into a jaxpr whose inputs are
    any captured constants followed by the unknown inputs.

    Returns:
      (jaxpr, pvals_out, consts): the staged jaxpr, one PartialVal per output
      of `f`, and the values for the jaxpr's leading constant inputs.
    """
    with new_main(PartialEvalTrace) as main:
        trace = PartialEvalTrace(main)
        arg_tracers = [trace.new_arg(pval) for pval in pvals_in]
        out_tracers = [full_raise(trace, out) for out in f(*arg_tracers)]
        pvals_out = [t.pval for t in out_tracers]

        def unknown(tracers):
            return [t for t in tracers if t.pval.is_unknown]

        jaxpr, consts = tracers_to_jaxpr(unknown(arg_tracers), unknown(out_tracers))
    return jaxpr, pvals_out, consts
接下来,我们需要实现 PartialEvalTrace
及其 PartialEvalTracer
。这个解释器会在跟踪数据依赖关系的同时动态构建 jaxpr。为此,它会在 PartialEvalTracer
节点(表示移出的值)和 JaxprRecipe
节点(表示如何根据其他值计算某些值的公式)之间构建一个二分有向无环图 (DAG)。一种配方是 JaxprEqnRecipe
,它对应于 JaxprEqn
的一次原语应用,但我们也有表示常量和 lambda 绑定器的配方类型。
from weakref import ref, ReferenceType
# Recipes record how the value of each PartialEvalTracer is produced.

# The tracer is a lambda-bound input of the jaxpr being built.
LambdaBindingRecipe = NamedTuple('LambdaBindingRecipe', [])

# The tracer stands for the known constant `val`.
ConstRecipe = NamedTuple('ConstRecipe', [('val', Any)])

# The tracers are the outputs of applying `prim` to `tracers_in`. Outputs are
# held through weak references to avoid cyclic garbage with the tracers.
JaxprEqnRecipe = NamedTuple('JaxprEqnRecipe', [
    ('prim', Primitive),
    ('tracers_in', list['PartialEvalTracer']),
    ('params', dict[str, Any]),
    ('avals_out', list[ShapedArray]),
    ('tracer_refs_out', list['ReferenceType[PartialEvalTracer]']),
])

JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
class PartialEvalTracer(Tracer):
    """Tracer for partial evaluation: a PartialVal plus the recipe producing it.

    `recipe` may be None, e.g. for known values created by `lift`.
    """
    pval: PartialVal
    recipe: JaxprRecipe | None

    def __init__(self, trace, pval, recipe):
        self._trace, self.pval, self.recipe = trace, pval, recipe

    @property
    def aval(self):
        return self.pval.aval

    def full_lower(self):
        """Lower to the underlying known value when there is one."""
        return full_lower(self.pval.const) if self.pval.is_known else self
PartialEvalTrace
包含用于构造 JaxprRecipe
和 PartialEvalTracer
图的逻辑。每个参数都对应于一个 LambdaBindingRecipe
叶节点,每个常量都是一个包含对常量引用的 ConstRecipe
叶节点。所有其他 tracer 和配方都来自 process_primitive
,它使用 JaxprEqnRecipe
创建 tracer。
对于大多数原语,process_primitive
逻辑很简单:如果所有输入都是已知的,那么我们可以在已知值上绑定原语(在 Python 中对其求值),而无需为输出构造 tracer。如果任何输入是未知的,那么我们将其分阶段输出到一个表示该原语应用的 JaxprEqnRecipe
中。要构建表示未知输出的 tracer,我们需要 avals,这些 avals 来自抽象评估规则。(请注意,tracer 引用 JaxprEqnRecipe
,而 JaxprEqnRecipe
引用 tracer;我们通过使用 weakref
来避免循环垃圾。)
该 process_primitive
逻辑适用于大多数原语,但 xla_call_p
需要递归处理。因此,我们在 partial_eval_rules
字典中特殊处理它的规则。
class PartialEvalTrace(Trace):
    """Trace that evaluates known values eagerly and records the rest.

    If all of a primitive's inputs are known, it is bound on the known values.
    Otherwise a rule from `partial_eval_rules` is used when one is
    registered. Failing that, a JaxprEqnRecipe is recorded and unknown output
    tracers are returned, with avals from the abstract evaluation rules.
    """

    def new_arg(self, pval: PartialVal) -> Any:
        return PartialEvalTracer(self, pval, LambdaBindingRecipe())

    def lift(self, val: Any) -> PartialEvalTracer:
        return PartialEvalTracer(self, PartialVal.known(val), None)

    pure = lift

    def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:
        """Return an unknown tracer for `tracer`, staging a known value as a const."""
        if not tracer.pval.is_unknown:
            shaped = raise_to_shaped(tracer.aval)
            return PartialEvalTracer(self, PartialVal.unknown(shaped),
                                     ConstRecipe(tracer.pval.const))
        return tracer

    def process_primitive(self, primitive, tracers, params):
        if all(t.pval.is_known for t in tracers):
            return bind(primitive, *map(full_lower, tracers), **params)
        special_rule = partial_eval_rules.get(primitive)
        if special_rule:
            return special_rule(self, tracers, **params)
        staged_in = [self.instantiate_const(t) for t in tracers]
        in_avals = [t.aval for t in staged_in]
        avals_out = abstract_eval_rules[primitive](*in_avals, **params)
        outs = [PartialEvalTracer(self, PartialVal.unknown(aval), None)
                for aval in avals_out]
        recipe = JaxprEqnRecipe(primitive, staged_in, params, avals_out,
                                [ref(t) for t in outs])
        for t in outs:
            t.recipe = recipe
        return outs


partial_eval_rules = {}
现在我们可以使用 PartialEvalTrace
构建 jaxpr 的图表示,我们需要一种机制将图表示转换为标准的 jaxpr。jaxpr 对应于图的拓扑排序。
def tracers_to_jaxpr(tracers_in: list[PartialEvalTracer],
                     tracers_out: list[PartialEvalTracer]):
    """Convert the recipe graph reachable from `tracers_out` into a Jaxpr.

    Tracers are visited in topological order. Lambda-bound tracers must be
    among `tracers_in`. Each distinct constant object (by identity) becomes
    one extra leading input binder. Each equation recipe becomes one JaxprEqn,
    emitted once even when it has several outputs.

    Returns:
      (jaxpr, constvals): the typechecked jaxpr, whose inputs are the constant
      binders followed by `tracers_in`, and the constants in binder order.
    """
    tracer_to_var: dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))
                                     for t in tracers_in}
    # Built once. The membership test below used to rebuild this set for
    # every lambda-bound tracer, which was quadratic.
    in_tracer_ids = {id(t) for t in tracers_in}
    constvar_to_val: dict[Var, Any] = {}
    constid_to_var: dict[int, Var] = {}
    processed_eqns: set[int] = set()
    eqns: list[JaxprEqn] = []
    for t in toposort(tracers_out, tracer_parents):
        if isinstance(t.recipe, LambdaBindingRecipe):
            assert id(t) in in_tracer_ids
        elif isinstance(t.recipe, ConstRecipe):
            val = t.recipe.val
            var = constid_to_var.get(id(val))
            if var is None:
                aval = raise_to_shaped(get_aval(val))
                var = constid_to_var[id(val)] = Var(aval)
                constvar_to_val[var] = val
            tracer_to_var[id(t)] = var
        elif isinstance(t.recipe, JaxprEqnRecipe):
            # A multi-output recipe is reached once per output tracer.
            if id(t.recipe) not in processed_eqns:
                eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))
                processed_eqns.add(id(t.recipe))
        else:
            raise TypeError(t.recipe)

    constvars, constvals = unzip2(constvar_to_val.items())
    in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]
    out_vars = [tracer_to_var[id(t)] for t in tracers_out]
    jaxpr = Jaxpr(in_binders, eqns, out_vars)
    typecheck_jaxpr(jaxpr)
    return jaxpr, constvals
def recipe_to_eqn(tracer_to_var: dict[int, Var], recipe: JaxprEqnRecipe
                  ) -> JaxprEqn:
    """Build the JaxprEqn for `recipe` and register Vars for its live outputs.

    Output tracers already garbage-collected (dead weak refs) still get a
    binder, but are not added to `tracer_to_var`.
    """
    inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]
    out_binders = [Var(aval) for aval in recipe.avals_out]
    for t_ref, var in zip(recipe.tracer_refs_out, out_binders):
        tracer = t_ref()
        if tracer is not None:
            tracer_to_var[id(tracer)] = var
    return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)
def tracer_parents(t: PartialEvalTracer) -> list[PartialEvalTracer]:
    """Return the input tracers of `t`'s equation recipe, or [] for leaves."""
    if isinstance(t.recipe, JaxprEqnRecipe):
        return t.recipe.tracers_in
    return []
显示代码单元格源代码
def toposort(out_nodes: list[Any], parents: Callable[[Any], list[Any]]):
    """Return the nodes reachable from `out_nodes`, parents before children.

    Nodes are compared by identity. `parents(node)` lists a node's parents
    (duplicates allowed). The result is validated with `check_toposort`.
    """
    if not out_nodes: return []
    out_nodes = remove_duplicates(out_nodes)

    # Count, for every reachable node, how many child edges point at it. Each
    # pop corresponds to one edge (or one root entry); a node's parents are
    # expanded only on its first visit.
    child_counts = {}
    stack = list(out_nodes)
    while stack:
        node = stack.pop()
        if id(node) in child_counts:
            child_counts[id(node)] += 1
        else:
            child_counts[id(node)] = 1
            stack.extend(parents(node))
    # Discount the initial visit of each root, which was not a child edge.
    for node in out_nodes:
        child_counts[id(node)] -= 1

    # Walk from the childless nodes toward the leaves: a parent is emitted
    # once its last remaining child has been emitted. This yields children
    # before parents, so the list is reversed at the end.
    sorted_nodes = []
    childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]
    while childless_nodes:
        node = childless_nodes.pop()
        sorted_nodes.append(node)
        for parent in parents(node):
            if child_counts[id(parent)] == 1:
                childless_nodes.append(parent)
            else:
                child_counts[id(parent)] -= 1

    sorted_nodes = sorted_nodes[::-1]
    check_toposort(sorted_nodes, parents)
    return sorted_nodes
def remove_duplicates(lst):
    """Drop repeated objects (compared by identity), keeping first occurrences."""
    seen_ids = set()
    unique = []
    for item in lst:
        if id(item) not in seen_ids:
            seen_ids.add(id(item))
            unique.append(item)
    return unique
def check_toposort(nodes: list[Any], parents: Callable[[Any], list[Any]]):
    """Assert that every node's parents appear earlier in `nodes` (by identity)."""
    visited = set()
    for node in nodes:
        for parent in parents(node):
            assert id(parent) in visited
        visited.add(id(node))
现在我们可以进行线性化了!
# Linearize sin at 3.0. The primal output equals sin(3.), and applying the
# linear map to 1.0 gives the derivative cos(3.).
y, sin_lin = linearize(sin, 3.)
print(y, sin(3.))
print(sin_lin(1.), cos(3.))
0.1411200080598672 0.1411200080598672
-0.9899924966004454 -0.9899924966004454
为了处理 linearize
-of-jit
,我们仍然需要为 xla_call_p
编写部分求值规则。除了 tracer 记账之外,主要任务是对 jaxpr 进行部分求值,将其“解压缩”为两个 jaxpr。
实际上需要编写两个规则:一个用于跟踪时间部分求值,我们称之为 xla_call_partial_eval
,另一个用于对 jaxpr 进行部分求值,我们称之为 xla_call_peval_eqn
。
def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
    """Trace-time partial evaluation rule for `xla_call_p`.

    Splits `jaxpr` into a known part and an unknown part. The known part is
    run now, as a nested xla_call on the known inputs. The unknown part is
    recorded as a JaxprEqnRecipe taking the residuals followed by the unknown
    inputs.

    Returns:
      One tracer per output of `jaxpr`, in the original order.
    """
    del num_consts  # Unused
    in_unknowns = [not t.pval.is_known for t in tracers]
    jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)
    known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
    known_vals = [t.pval.const for t in known_tracers]
    # jaxpr1 returns its known outputs followed by `num_res` residuals.
    outs1_res = bind(xla_call_p, *known_vals, jaxpr=jaxpr1, num_consts=0)
    outs1, res = split_list(outs1_res, len(jaxpr1.outs) - num_res)
    # Residuals feed the staged-out call, so turn them into unknown tracers
    # (known values become ConstRecipes via instantiate_const).
    res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
    outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
             for v in jaxpr2.outs]
    eqn = JaxprEqnRecipe(xla_call_p, res_tracers + unknown_tracers,
                         dict(jaxpr=jaxpr2, num_consts=0),
                         [v.aval for v in jaxpr2.outs], map(ref, outs2))
    for t in outs2: t.recipe = eqn
    # Put known and unknown outputs back into the original output order.
    return merge_lists(out_unknowns, outs1, outs2)


partial_eval_rules[xla_call_p] = xla_call_partial_eval
def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
                       instantiate: list[bool] | None = None,
                       ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
    """Split `jaxpr` into a known jaxpr and an unknown jaxpr.

    With inputs partitioned by `in_unknowns`, the result is
    ``jaxpr1 : known_ins -> (known_outs, residuals)`` and
    ``jaxpr2 : (residuals, unknown_ins) -> unknown_outs``. Residuals are the
    known Vars read by unknown equations. If `instantiate` is given, outputs
    whose flag is set are forced to be unknown, i.e. produced by jaxpr2
    (known ones are passed through as residuals).

    Returns:
      (jaxpr1, jaxpr2, out_unknowns, num_res).
    """
    env: dict[Var, bool] = {}  # Var -> is it unknown?
    residuals: set[Var] = set()

    def read(x: Atom) -> bool:
        # Literals are always known.
        return type(x) is Var and env[x]

    def write(unk: bool, v: Var) -> None:
        env[v] = unk

    def new_res(x: Atom) -> Atom:
        # Record a known Var needed by the unknown side.
        if type(x) is Var: residuals.add(x)
        return x

    eqns1, eqns2 = [], []
    map(write, in_unknowns, jaxpr.in_binders)
    for eqn in jaxpr.eqns:
        unks_in = map(read, eqn.inputs)
        rule = partial_eval_jaxpr_rules.get(eqn.primitive)
        if rule:
            # Primitives such as xla_call split themselves into two equations.
            eqn1, eqn2, unks_out, res = rule(unks_in, eqn)
            eqns1.append(eqn1); eqns2.append(eqn2); residuals.update(res)
            map(write, unks_out, eqn.out_binders)
        elif any(unks_in):
            inputs = [v if unk else new_res(v) for unk, v in zip(unks_in, eqn.inputs)]
            eqns2.append(JaxprEqn(eqn.primitive, inputs, eqn.params, eqn.out_binders))
            map(partial(write, True), eqn.out_binders)
        else:
            eqns1.append(eqn)
            map(partial(write, False), eqn.out_binders)
    out_unknowns = map(read, jaxpr.outs)
    if instantiate is not None:
        for v, uk, inst in zip(jaxpr.outs, out_unknowns, instantiate):
            if inst and not uk: new_res(v)
        out_unknowns = map(op.or_, out_unknowns, instantiate)

    residuals, num_res = list(residuals), len(residuals)
    assert all(type(v) is Var for v in residuals), residuals

    ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
    outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)

    jaxpr1 = Jaxpr(ins1, eqns1, outs1 + residuals)
    jaxpr2 = Jaxpr(residuals + ins2, eqns2, outs2)
    typecheck_partial_eval_jaxpr(jaxpr, in_unknowns, out_unknowns, jaxpr1, jaxpr2)

    return jaxpr1, jaxpr2, out_unknowns, num_res
def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):
    """Check that `jaxpr1` and `jaxpr2` are a well-typed split of `jaxpr`.

    With ``jaxpr : (a1, a2) -> (b1, b2)`` partitioned by `unks_in`/`unks_out`,
    this requires ``jaxpr1 : a1 -> (b1, res)`` and ``jaxpr2 : (res, a2) -> b2``.

    Raises:
      TypeError: naming the first mismatched group of types.
    """
    jaxprty = typecheck_jaxpr(jaxpr)    # (a1, a2) -> (b1, b2 )
    jaxpr1ty = typecheck_jaxpr(jaxpr1)  # a1 -> (b1, res)
    jaxpr2ty = typecheck_jaxpr(jaxpr2)  # (res, a2) -> b2

    a1, a2 = partition_list(unks_in, jaxprty.in_types)
    b1, b2 = partition_list(unks_out, jaxprty.out_types)
    b1_, res = split_list(jaxpr1ty.out_types, len(b1))
    res_, a2_ = split_list(jaxpr2ty.in_types, len(res))
    b2_ = jaxpr2ty.out_types

    # Each group is checked once; the old code compared b2 twice and raised
    # message-less TypeErrors.
    checks = [
        ("known inputs (a1)", jaxpr1ty.in_types, a1),
        ("known outputs (b1)", b1_, b1),
        ("residuals", res_, res),
        ("unknown inputs (a2)", a2_, a2),
        ("unknown outputs (b2)", b2_, b2),
    ]
    for what, got, expected in checks:
        if got != expected:
            raise TypeError(f"partial eval split has mismatched {what}: "
                            f"{got} != {expected}")


partial_eval_jaxpr_rules = {}
def xla_call_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
                       ) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Var]]:
    """Jaxpr-level partial evaluation rule for an `xla_call` equation.

    Splits the equation into a known call, which produces the known outputs
    plus fresh residual Vars, and an unknown call, which consumes those
    residuals followed by the unknown inputs.
    """
    inner = eqn.params['jaxpr']
    known_jaxpr, unknown_jaxpr, unks_out, num_res = partial_eval_jaxpr(inner, unks_in)
    known_ins, unknown_ins = partition_list(unks_in, eqn.inputs)
    known_outs, unknown_outs = partition_list(unks_out, eqn.out_binders)
    residuals = [Var(binder.aval) for binder in unknown_jaxpr.in_binders[:num_res]]
    known_eqn = JaxprEqn(xla_call_p, known_ins,
                         dict(jaxpr=known_jaxpr, num_consts=0),
                         known_outs + residuals)
    unknown_eqn = JaxprEqn(xla_call_p, residuals + unknown_ins,
                           dict(jaxpr=unknown_jaxpr, num_consts=0), unknown_outs)
    return known_eqn, unknown_eqn, unks_out, residuals


partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn
有了这些,我们可以随意组合 linearize
和 jit
:
# linearize of a jitted function uses the xla_call partial-evaluation rule.
@jit
def f(x):
    y = sin(x) * 2.
    z = - y + x
    return z
y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
2.7177599838802657 2.979984993200891
# Nested jit: `f` calls the jitted `g`, so the inner xla_call equation is
# split by the jaxpr-level rule (xla_call_peval_eqn).
@jit
def f(x):
    y = sin(x) * 2.
    z = g(x, y)
    return z
@jit
def g(x, y):
    return cos(x) + y
y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
-0.7077524804807109 -2.121105001260758
vjp
和 grad
#
vjp
转换的工作方式很像线性化。它的类型签名是类似的:
linearize : (a -> b) -> a -> (b, T a -o T b)
vjp : (a -> b) -> a -> (b, T b -o T a)
唯一的区别在于我们在返回线性计算之前对其进行转置,使其从类型 T a -o T b
变为类型 T b -o T a
。也就是说,我们将 vjp
实现为,本质上:
def vjp(f, x):
y, f_lin = linearize(f, x)
f_vjp = lambda y_bar: transpose(f_lin)(y_bar)
return y, f_vjp
由于我们将线性计算作为 jaxpr 而不是仅仅作为 Python 可调用对象,因此我们可以将转置转换实现为 jaxpr 解释器。
def vjp_flat(f, *primals_in):
    """VJP of flat `f` at `primals_in`: linearize, then transpose.

    Returns:
      (primals_out, f_vjp), where ``f_vjp(*cts)`` evaluates the transposed
      linear jaxpr on output cotangents and returns the input cotangents.
    """
    known = [PartialVal.known(x) for x in primals_in]
    tangent_pvals_in = [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in]

    def f_jvp(*primals_tangents_in):
        primals, tangents = split_half(primals_tangents_in)
        primals_out, tangents_out = jvp(f, primals, tangents)
        return [*primals_out, *tangents_out]

    jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, known + tangent_pvals_in)  # linearize
    primal_pvals, _ = split_half(pvals_out)
    assert all(pval.is_known for pval in primal_pvals)
    primals_out = [pval.const for pval in primal_pvals]
    # Constants are known inputs; the tangent inputs are the linear ones.
    transpose_inputs = consts + [UndefPrimal(p.aval) for p in tangent_pvals_in]

    def f_vjp(*cts):
        return eval_jaxpr_transposed(jaxpr, transpose_inputs, cts)

    return primals_out, f_vjp
def vjp(f, *primals_in):
    """Pytree-aware VJP: return ``(primals_out, f_vjp)``.

    ``f_vjp(*cotangents)`` returns input cotangents structured like
    `primals_in`. NOTE: the tree structure of the cotangents is not checked.
    """
    flat_primals, in_tree = tree_flatten(primals_in)
    flat_f, out_tree = flatten_fun(f, in_tree)
    flat_primals_out, flat_f_vjp = vjp_flat(flat_f, *flat_primals)
    primals_out = tree_unflatten(out_tree(), flat_primals_out)

    def f_vjp(*cotangents_out):
        flat_cts, _ = tree_flatten(cotangents_out)
        return tree_unflatten(in_tree, flat_f_vjp(*flat_cts))

    return primals_out, f_vjp
class UndefPrimal(NamedTuple):
    """Placeholder for a linear input whose primal value is undefined.

    Only its aval is known. Transposition computes a cotangent for it.
    """
    aval: ShapedArray


register_pytree_node(
    UndefPrimal,
    lambda undef: (undef.aval, ()),             # flatten: aval, no children
    lambda aval, _children: UndefPrimal(aval),  # unflatten
)
我们使用 UndefPrimal
实例来标记我们想要对其转置的参数。之所以需要它们,是因为一般来说,如果把闭包捕获的值也显式写成参数,我们希望将类型为 a -> b -o c
的函数转置为类型为 a -> c -o b
的函数。更一般地说,函数在其上呈线性的输入可以分散在参数列表中。因此,我们使用 UndefPrimal
来指示线性位置。我们将 UndefPrimal
注册为 pytree 节点,因为 pytree 机制提供了一种方便的方法从参数列表中修剪掉这些占位符。
接下来,我们可以编写 eval_jaxpr_transposed
,以及所有可以在至少一个参数中呈线性的原语的转置规则:
# NB: the analogous function in JAX is called 'backward_pass'
def eval_jaxpr_transposed(jaxpr: Jaxpr, args: list[Any], cotangents: list[Any]
                          ) -> list[Any]:
  """Evaluate the transpose of a linear `jaxpr` by walking its equations backward.

  Args:
    jaxpr: the jaxpr to transpose.
    args: one entry per `jaxpr.in_binders`. An `UndefPrimal` marks an input
      the jaxpr is linear in (to be transposed). Any other value is a known
      primal (residual).
    cotangents: one cotangent per entry of `jaxpr.outs`.

  Returns:
    The cotangents of the `UndefPrimal` inputs, in input order. An input that
    never received a cotangent gets zeros of its aval's shape and dtype.
  """
  primal_env: dict[Var, Any] = {}
  ct_env: dict[Var, Any] = {}

  def read_primal(x: Atom) -> Any:
    # Only known inputs are ever written to primal_env, so every other Var
    # (including every equation output) reads back as an UndefPrimal.
    # Literals contribute their value directly.
    return primal_env.get(x, UndefPrimal(x.aval)) if type(x) is Var else x.val

  def write_primal(v: Var, val: Any) -> None:
    # Store known primals only; UndefPrimal placeholders stay implicit.
    if type(val) is not UndefPrimal:
      primal_env[v] = val

  def read_cotangent(v: Var) -> Any:
    # Pop, so each accumulated cotangent is consumed once. A Var that never
    # received a cotangent gets zeros.
    return ct_env.pop(v, np.zeros(v.aval.shape, v.aval.dtype))

  def write_cotangent(x: Atom, val: Any):
    # Sum cotangents for Vars used more than once. Drop cotangents aimed at
    # literals and `None` (meaning "no contribution").
    if type(x) is Var and val is not None:
      ct_env[x] = add(ct_env[x], val) if x in ct_env else val

  # NOTE(review): the results of these `map` calls are discarded, so this
  # assumes `map` is the file's eager (list-returning) helper rather than the
  # lazy builtin; confirm.
  map(write_primal, jaxpr.in_binders, args)
  map(write_cotangent, jaxpr.outs, cotangents)
  # Reverse order: an equation's output cotangents are complete only after
  # every later equation that consumes them has been transposed.
  for eqn in jaxpr.eqns[::-1]:
    primals_in = map(read_primal, eqn.inputs)
    cts_in = map(read_cotangent, eqn.out_binders)
    rule = transpose_rules[eqn.primitive]
    cts_out = rule(cts_in, *primals_in, **eqn.params)
    map(write_cotangent, eqn.inputs, cts_out)

  return [read_cotangent(v) for v, x in zip(jaxpr.in_binders, args)
          if type(x) is UndefPrimal]
transpose_rules = {}

def mul_transpose_rule(cts, x, y):
  """Transpose of `mul` with respect to its single linear operand.

  Exactly one of `x`, `y` must be an `UndefPrimal`. It receives the output
  cotangent scaled by the other (known) operand; the known one gets `None`.
  """
  ct, = cts
  x_is_linear = type(x) is UndefPrimal
  assert x_is_linear ^ (type(y) is UndefPrimal)
  if x_is_linear:
    return [mul(ct, y), None]
  return [None, mul(x, ct)]
transpose_rules[mul_p] = mul_transpose_rule
def neg_transpose_rule(cts, x):
  """Negation is its own transpose: negate the incoming cotangent."""
  (ct,) = cts
  assert type(x) is UndefPrimal
  return [neg(ct)]
transpose_rules[neg_p] = neg_transpose_rule
def add_transpose_rule(cts, x, y):
  """Transpose of `add`: both operands receive the output cotangent unchanged.

  Cotangents for non-linear operands are discarded by the caller, which only
  accumulates cotangents into Vars.
  """
  (ct,) = cts
  return [ct] * 2
transpose_rules[add_p] = add_transpose_rule
def reduce_sum_transpose_rule(cts, x, *, axis):
  """Transpose of a sum reduction: broadcast the cotangent back over `axis`."""
  (y_bar,) = cts
  input_shape = x.aval.shape
  return [broadcast(y_bar, input_shape, axis)]
transpose_rules[reduce_sum_p] = reduce_sum_transpose_rule
def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
  """Transpose rule for `xla_call`: bind a call to the transposed inner jaxpr.

  The transposed call takes its new constants, the residual inputs selected
  by `partition_list`, and the output cotangents. Its results become the
  cotangents of the `UndefPrimal` inputs, returned in position, with `None`
  for every other input.
  """
  del num_consts  # Unused
  undef_mask = [type(v) is UndefPrimal for v in invals]
  t_jaxpr, t_consts = transpose_jaxpr(jaxpr, tuple(undef_mask))
  res_vals, _ = partition_list(undef_mask, invals)
  ct_iter = iter(bind(xla_call_p, *t_consts, *res_vals, *cts,
                      jaxpr=t_jaxpr, num_consts=len(t_consts)))
  return [next(ct_iter) if is_undef else None for is_undef in undef_mask]
transpose_rules[xla_call_p] = xla_call_transpose_rule
@lru_cache
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
                    ) -> tuple[Jaxpr, list[Any]]:
  """Stage out the transpose of `jaxpr` as a new jaxpr (cached per arguments).

  `undef_primals[i]` says whether input `i` is linear (to be transposed).
  Returns the typechecked transposed jaxpr and the constants it closes over.
  """
  in_avals, out_avals = typecheck_jaxpr(jaxpr)
  # Linear inputs become UndefPrimal placeholders; the rest stay plain avals.
  arg_specs = tuple(UndefPrimal(aval) if is_undef else aval
                    for aval, is_undef in zip(in_avals, undef_primals))
  run_transposed = partial(eval_jaxpr_transposed, jaxpr)
  transposed, consts, _ = make_jaxpr(run_transposed, arg_specs,
                                     tuple(out_avals))
  typecheck_jaxpr(transposed)
  return transposed, consts
现在我们可以进行线性化和转置了,我们终于可以编写 grad
了:
def grad(f):
  """Return a function computing the gradient of scalar-output `f`.

  The gradient is taken with respect to the first positional argument; any
  further arguments are passed through to `f`.

  Raises:
    TypeError: (from the returned function) if `f`'s output is not a scalar.
  """
  def gradfun(x, *xs):
    y, f_vjp = vjp(f, x, *xs)
    if np.shape(y) != ():
      raise TypeError(f"grad requires a scalar-output function, but the "
                      f"output has shape {np.shape(y)}")
    # Seed the VJP with a cotangent of 1 in the output's dtype; keep only the
    # cotangent of the first argument.
    x_bar, *_ = f_vjp(np.ones(np.shape(y), np.result_type(y)))
    return x_bar
  return gradfun
# The VJP of sin at 3., applied to a cotangent of 1., should equal cos(3.).
y, f_vjp = vjp(sin, 3.)
print(f_vjp(1.), cos(3.))
(np.float64(-0.9899924966004454),) -0.9899924966004454
# Gradient of the Part 1 example: f(x) = -2 sin(x) + x, so
# f'(x) = 1 - 2 cos(x).
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

print(grad(f)(3.))
2.979984993200891
# grad through nested jitted functions: f(x) = g(2x) = 2 cos(2x), so
# f'(x) = -4 sin(2x). `g` is looked up when `f` runs, so defining it after
# `f` is fine.
@jit
def f(x):
  y = x * 2.
  z = g(y)
  return z

@jit
def g(x):
  return cos(x) * 2.

print(grad(f)(3.))
1.1176619927957034
这是一个关于组合性的压力测试:
# from core_test.py fun_with_nested_calls_2
# Compositionality stress test: jitted lambdas that close over values from
# several enclosing scopes (`x`, `y`, `w`), plus a jvp taken inside a jit.
def foo(x):
  @jit
  def bar(y):
    def baz(w):
      # The lambda's own parameter `x` shadows foo's `x`; it returns the
      # closed-over `y`.
      q = jit(lambda x: y)(x)
      q = q + jit(lambda: y)()
      q = q + jit(lambda y: w + y)(y)
      q = jit(lambda w: jit(sin)(x) * y)(1.0) + q
      return q
    p, t = jvp(baz, (x + 1.0,), (y,))
    return t + (x * p)
  return bar(x)
def assert_allclose(*vals):
  """Assert that every adjacent pair in `vals` is numerically close.

  Zero or one value is trivially accepted. Raises AssertionError (via
  `np.testing.assert_allclose`) on the first mismatching pair.
  """
  for prev, cur in zip(vals, vals[1:]):
    np.testing.assert_allclose(prev, cur)
# NOTE(review): these checks exercise `f` (the jitted f/g example defined
# earlier), not `foo` defined just above; confirm which function the
# stress test is meant to cover.

# Primal values must agree across eager, jit, jvp and jvp-of-jit.
ans1 = f(3.)
ans2 = jit(f)(3.)
ans3, _ = jvp(f, (3.,), (5.,))
ans4, _ = jvp(jit(f), (3.,), (5.,))
assert_allclose(ans1, ans2, ans3, ans4)

# First derivatives via several grad/jvp/jit compositions.
deriv1 = grad(f)(3.)
deriv2 = grad(jit(f))(3.)
deriv3 = jit(grad(jit(f)))(3.)
_, deriv4 = jvp(f, (3.,), (1.,))
_, deriv5 = jvp(jit(f), (3.,), (1.,))
assert_allclose(deriv1, deriv2, deriv3, deriv4, deriv5)

# Second derivatives via several grad/jvp/jit compositions. hess6 and hess7
# differ in where jit sits: jit(grad(f)) versus grad(jit(f)).
hess1 = grad(grad(f))(3.)
hess2 = grad(grad(jit(f)))(3.)
hess3 = grad(jit(grad(f)))(3.)
hess4 = jit(grad(grad(f)))(3.)
_, hess5 = jvp(grad(f), (3.,), (1.,))
_, hess6 = jvp(jit(grad(f)), (3.,), (1.,))
_, hess7 = jvp(grad(jit(f)), (3.,), (1.,))
assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)
第 5 部分:控制流原语 cond
#
接下来,我们将为分阶段输出(staged-out)的控制流添加高阶原语。这些原语类似于第 3 部分中的 jit
,另一个高阶原语,但不同之处在于它们由多个可调用对象参数化,而不仅仅是一个。
添加 cond
#
我们引入了一个 cond
原语,用于表示在 jaxpr 内部有条件地应用一个函数或另一个函数。我们将 cond
的类型写为 Bool -> (a -> b) -> (a -> b) -> a -> b
。换句话说,cond
接收一个表示谓词的布尔值和两个类型相同的函数。根据谓词的值,它将其中一个函数应用于其最终参数。
在 Python 中,我们将其表示为一个函数,该函数本身接受两个函数作为参数。与 jit
一样,第一步是对其可调用参数调用 make_jaxpr
,将其转换为 jaxpr。
def cond(pred, true_fn, false_fn, *operands):
  """Apply `true_fn` or `false_fn` to `operands` depending on `pred`.

  Both branches are traced to jaxprs over the operands' shaped avals and
  staged out as a single `cond_p` application.

  Raises:
    TypeError: if the branches return different pytree structures or their
      jaxprs have different types.
  """
  avals_in = [raise_to_shaped(get_aval(x)) for x in operands]
  true_jaxpr, true_consts, out_tree = make_jaxpr(true_fn, *avals_in)
  false_jaxpr, false_consts, out_tree_ = make_jaxpr(false_fn, *avals_in)
  if out_tree != out_tree_:
    raise TypeError(f"cond branches must return the same tree structure, "
                    f"got {out_tree} and {out_tree_}")
  # Make both jaxprs accept both constant lists, so they share one signature.
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  true_type = typecheck_jaxpr(true_jaxpr)
  false_type = typecheck_jaxpr(false_jaxpr)
  if true_type != false_type:
    raise TypeError(f"cond branches must have the same type, "
                    f"got {true_type} and {false_type}")
  outs = bind_cond(pred, *true_consts, *false_consts, *operands,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  return tree_unflatten(out_tree, outs)
cond_p = Primitive('cond')
def _join_jaxpr_consts(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
                       ) -> tuple[Jaxpr, Jaxpr]:
  """Give two jaxprs the common input signature `consts1 + consts2 + rest`.

  `jaxpr1` takes `n1` leading constant inputs and `jaxpr2` takes `n2`. The
  remaining inputs must have matching types. Each returned jaxpr accepts both
  constant lists and ignores the other jaxpr's constants.
  """
  jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
  assert jaxpr1_type.in_types[n1:] == jaxpr2_type.in_types[n2:]
  consts1, rest1 = split_list(jaxpr1.in_binders, n1)
  consts2, rest2 = split_list(jaxpr2.in_binders, n2)
  # The binders standing in for the *other* jaxpr's constants are unused, so
  # make them fresh Vars instead of sharing Var objects across jaxprs. Sharing
  # would bind one Var twice in a single binder list whenever jaxpr1 is
  # jaxpr2 and there are constants (e.g. two identical branches coming from
  # a cached trace). A binder list that binds the same Var twice is presumably
  # rejected by typecheck_jaxpr; confirm.
  unused2 = [Var(v.aval) for v in consts2]
  unused1 = [Var(v.aval) for v in consts1]
  new_jaxpr1 = Jaxpr(consts1 + unused2 + rest1, jaxpr1.eqns, jaxpr1.outs)
  new_jaxpr2 = Jaxpr(unused1 + consts2 + rest2, jaxpr2.eqns, jaxpr2.outs)
  return new_jaxpr1, new_jaxpr2
def bind_cond(pred, *args, true_jaxpr, false_jaxpr):
  """Bind `cond_p` after checking both branch jaxprs take exactly `args`."""
  num_args = len(args)
  assert (num_args == len(true_jaxpr.in_binders)
          and num_args == len(false_jaxpr.in_binders))
  branch_params = dict(true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  return bind(cond_p, pred, *args, **branch_params)
我们要求 true_jaxpr
和 false_jaxpr
具有相同的类型,但由于它们可能捕获不同的常量(而 jaxpr 只能表示闭合项,即不能含有自由变量,捕获的值需经闭包转换变为显式输入),我们需要使用辅助函数 _join_jaxpr_consts
来使两个 jaxpr 的输入绑定器列表保持一致。(为了更节省,我们可以尝试找出形状相同的常量对并加以合并,但这里我们只是简单地拼接两个常量列表。)
接下来,我们可以开始为 cond
添加解释器规则。其求值规则很简单:
def cond_impl(pred, *operands, true_jaxpr, false_jaxpr):
  """Evaluate cond eagerly by interpreting the branch selected by `pred`."""
  chosen_branch = true_jaxpr if pred else false_jaxpr
  return eval_jaxpr(chosen_branch, operands)
impl_rules[cond_p] = cond_impl
# A cond with no operands and a concrete predicate: the true branch is taken.
out = cond(True, lambda: 3, lambda: 4)
print(out)
3
对于其 JVP 和 vmap 规则,我们只需要调用为 jit
创建的相同 jvp_jaxpr
和 vmap_jaxpr
实用程序,然后再调用一次 _join_jaxpr_consts
。
def cond_jvp_rule(primals, tangents, *, true_jaxpr, false_jaxpr):
  """JVP rule for cond.

  Both branches are transformed with `jvp_jaxpr`, their constants are joined,
  and a single cond computes primal and tangent outputs together. The
  predicate's tangent is ignored.
  """
  pred, *primal_args = primals
  _, *tangent_args = tangents
  jvp_true, consts_true = jvp_jaxpr(true_jaxpr)
  jvp_false, consts_false = jvp_jaxpr(false_jaxpr)
  jvp_true, jvp_false = _join_jaxpr_consts(
      jvp_true, jvp_false, len(consts_true), len(consts_false))
  assert typecheck_jaxpr(jvp_true) == typecheck_jaxpr(jvp_false)
  outs = bind_cond(pred, *consts_true, *consts_false, *primal_args,
                   *tangent_args, true_jaxpr=jvp_true, false_jaxpr=jvp_false)
  # The outputs are split in half: primal outputs, then tangent outputs.
  primals_out, tangents_out = split_half(outs)
  return primals_out, tangents_out
jvp_rules[cond_p] = cond_jvp_rule
# JVP through cond: d/dx (x * x) at x = 1 with tangent 1 gives 2.
out, out_tan = jvp(lambda x: cond(True, lambda: x * x, lambda: 0.), (1.,), (1.,))
print(out_tan)
2.0
def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr):
  """Batching rule for cond with an unbatched predicate.

  Both branches are batched with `vmap_jaxpr`, their constants are joined,
  and every output comes back batched along axis 0. A batched predicate is
  not supported yet.
  """
  pred, *operands = vals_in
  pred_dim, *operand_dims = dims_in
  if pred_dim is not not_mapped:
    raise NotImplementedError  # TODO
  in_dims = tuple(operand_dims)
  vm_true, consts_true = vmap_jaxpr(true_jaxpr, axis_size, in_dims)
  vm_false, consts_false = vmap_jaxpr(false_jaxpr, axis_size, in_dims)
  vm_true, vm_false = _join_jaxpr_consts(
      vm_true, vm_false, len(consts_true), len(consts_false))
  assert typecheck_jaxpr(vm_true) == typecheck_jaxpr(vm_false)
  outs = bind_cond(pred, *consts_true, *consts_false, *operands,
                   true_jaxpr=vm_true, false_jaxpr=vm_false)
  out_dims = [0] * len(outs)
  return outs, out_dims
vmap_rules[cond_p] = cond_vmap_rule
xs = np.array([1., 2., 3])
# The predicate is a constant (unbatched); only the closed-over `x` is mapped.
out = vmap(lambda x: cond(True, lambda: x + 1., lambda: 0.), (0,))(xs)
print(out)
[2. 3. 4.]
请注意,我们目前不支持谓词值本身是批处理的情况。在主线 JAX 中,我们通过将条件转换为 select 原语来处理这种情况。只要 true_fun
和 false_fun
不涉及任何具有副作用的原语,该转换在语义上是正确的。
另一个未在此处表示但存在于主线 JAX 中的情况是,对两个类型相同的 jaxpr 应用变换可能会导致类型不同的 jaxpr。例如,将主线 JAX 版本的 vmap_jaxpr
应用于恒等函数 jaxpr
{ lambda a:float32[] .
let
in ( a ) }
如果批处理大小为 10,则会生成一个具有批处理输出的 jaxpr,其类型为 [float32[10]] -> [float32[10]]
,而将其应用于零函数 jaxpr
{ lambda a:float32[] .
let
in ( 0. ) }
则会生成一个具有未批处理输出的 jaxpr,其类型为 [float32[10]] -> [float32[]]
。这是一种优化,旨在避免不必要地批处理值。但这意味着在 cond
中,我们需要额外一步来连接两个转换后的 jaxpr,以使输出类型保持一致。我们不需要在这里执行此步骤,因为我们选择 vmap_jaxpr
始终在第一个轴(leading axis)上批处理所有输出。
接下来,我们可以开始处理抽象求值和 XLA 降级规则。
def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr):
  """Type rule for cond.

  Returns the branches' output types.

  Raises:
    TypeError: if the predicate is not a scalar bool, the branch jaxprs have
      different types, or the operand types (including their number) do not
      match the branches' input types.
  """
  if pred_type != ShapedArray((), np.dtype('bool')):
    raise TypeError(f"cond predicate must be a scalar bool, got {pred_type}")
  jaxpr_type = typecheck_jaxpr(true_jaxpr)
  if jaxpr_type != typecheck_jaxpr(false_jaxpr):
    raise TypeError("cond branches must have the same type")
  # zip() silently ignores extra or missing operands, so check arity first.
  if len(in_types) != len(jaxpr_type.in_types):
    raise TypeError(f"cond expects {len(jaxpr_type.in_types)} operands, "
                    f"got {len(in_types)}")
  if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
    raise TypeError("cond operand types do not match the branch input types")
  return jaxpr_type.out_types
abstract_eval_rules[cond_p] = cond_abstract_eval
def cond_translation(c, in_avals, out_avals, in_vals, *, true_jaxpr, false_jaxpr):
  """Lower cond to an `hlo.IfOp`, emitting each branch jaxpr into its region.

  `in_vals[0]` is the lowered predicate. The remaining values are the inputs
  given to both branch jaxprs. Returns the IfOp's results.
  """
  del in_avals  # Unused
  pred, *in_vals = in_vals
  # One IfOp result per output aval.
  op = hlo.IfOp([aval_to_ir_type(aval) for aval in out_avals], pred)
  # Each branch region gets a new block. The branch jaxpr is emitted there,
  # and its outputs are returned from the region.
  with ir.InsertionPoint(op.true_branch.blocks.append()):
    hlo.return_(jaxpr_subcomp(c, true_jaxpr, in_vals))
  with ir.InsertionPoint(op.false_branch.blocks.append()):
    hlo.return_(jaxpr_subcomp(c, false_jaxpr, in_vals))
  return op.results
hlo_translations[cond_p] = cond_translation
# cond staged out under jit; the false branch is selected.
out = jit(lambda: cond(False, lambda: 1, lambda: 2))()
print(out)
2
最后,为了支持反向模式自动微分,我们需要部分求值和转置规则。对于部分求值,我们需要引入另一个 jaxpr 处理实用程序 _join_jaxpr_res
,来处理将部分求值应用于 true_fun
和 false_fun
通常会导致不同的残差这一事实。我们使用 _join_jaxpr_res
使转换后的 jaxpr 的输出类型保持一致(而 _join_jaxpr_consts
处理的是输入类型)。
def cond_partial_eval(trace, tracers, *, true_jaxpr, false_jaxpr):
  """Partial-evaluation rule for cond.

  Splits each branch into a known part and an unknown part:
  - The known part is evaluated now, by a cond on the known inputs.
  - The unknown part is staged out as a cond equation. Its inputs are the
    predicate, the residuals and the unknown inputs.

  The predicate must be known.
  """
  pred_tracer, *tracers = tracers
  assert pred_tracer.pval.is_known
  pred = pred_tracer.pval.const
  in_uks = [not t.pval.is_known for t in tracers]
  *jaxprs, out_uks, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, in_uks)
  t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs
  known_tracers, unknown_tracers = partition_list(in_uks, tracers)
  known_vals = [t.pval.const for t in known_tracers]
  # Known cond: returns the known outputs followed by `num_res` residuals.
  outs1_res = bind_cond(pred, *known_vals,
                        true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1)
  outs1, res = split_list(outs1_res, len(outs1_res) - num_res)
  # The staged-out equation takes the predicate and residuals as tracers of
  # this trace.
  pred_tracer_ = trace.instantiate_const(full_raise(trace, pred_tracer))
  res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
  outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
           for v in t_jaxpr2.outs]
  eqn = JaxprEqnRecipe(cond_p, [pred_tracer_, *res_tracers, *unknown_tracers],
                       dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
                       [v.aval for v in t_jaxpr2.outs], map(ref, outs2))
  # Point each unknown output tracer at the recipe that produces it.
  for t in outs2: t.recipe = eqn
  # Presumably interleaves known/unknown outputs back into output order
  # according to out_uks; confirm against merge_lists.
  return merge_lists(out_uks, outs1, outs2)
partial_eval_rules[cond_p] = cond_partial_eval
def _cond_partial_eval(true_jaxpr: Jaxpr, false_jaxpr: Jaxpr, in_uks: list[bool]
                       ) -> tuple[Jaxpr, Jaxpr, Jaxpr, Jaxpr, list[bool], int]:
  """Partially evaluate both cond branches consistently.

  Returns `(t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2, out_uks, num_res)`:
  - the known (`*1`) and unknown (`*2`) halves of each branch, joined so that
    both branches of each half have the same type;
  - the output-unknown mask shared by both branches;
  - the total residual count.
  """
  # First pass: find which outputs each branch makes unknown.
  _, _, t_out_uks, _ = partial_eval_jaxpr(true_jaxpr , in_uks)
  _, _, f_out_uks, _ = partial_eval_jaxpr(false_jaxpr, in_uks)
  # An output is unknown if it is unknown in either branch.
  out_uks = map(op.or_, t_out_uks, f_out_uks)
  # Second pass: split both branches again using the shared output mask
  # (presumably forcing those outputs to be unknown in both; confirm).
  t_jaxpr1, t_jaxpr2, _, t_nres = partial_eval_jaxpr(true_jaxpr , in_uks, out_uks)
  f_jaxpr1, f_jaxpr2, _, f_nres = partial_eval_jaxpr(false_jaxpr, in_uks, out_uks)
  # Known halves differ in residual outputs; unknown halves differ in residual
  # inputs. Pad each pair so the two branches agree.
  t_jaxpr1, f_jaxpr1 = _join_jaxpr_res(t_jaxpr1, f_jaxpr1, t_nres, f_nres)
  t_jaxpr2, f_jaxpr2 = _join_jaxpr_consts(t_jaxpr2, f_jaxpr2, t_nres, f_nres)
  assert typecheck_jaxpr(t_jaxpr1) == typecheck_jaxpr(f_jaxpr1)
  assert typecheck_jaxpr(t_jaxpr2) == typecheck_jaxpr(f_jaxpr2)
  num_res = t_nres + f_nres
  return t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2, out_uks, num_res
def _join_jaxpr_res(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
                    ) -> tuple[Jaxpr, Jaxpr]:
  """Pad two jaxprs' trailing residual outputs so both return the same types.

  `jaxpr1` ends with `n1` residual outputs and `jaxpr2` with `n2`. Afterwards
  each returns `outs + res1 + res2`, filling the other jaxpr's residual slots
  with zero literals.
  """
  type1, type2 = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
  num_main1 = len(jaxpr1.outs) - n1
  num_main2 = len(jaxpr2.outs) - n2
  main_types1, _ = split_list(type1.out_types, num_main1)
  main_types2, _ = split_list(type2.out_types, num_main2)
  assert main_types1 == main_types2
  main1, res1 = split_list(jaxpr1.outs, num_main1)
  main2, res2 = split_list(jaxpr2.outs, num_main2)

  def zeros_for(atoms):
    return [Lit(np.zeros(a.aval.shape, a.aval.dtype)) for a in atoms]

  joined1 = Jaxpr(jaxpr1.in_binders, jaxpr1.eqns,
                  main1 + res1 + zeros_for(res2))
  joined2 = Jaxpr(jaxpr2.in_binders, jaxpr2.eqns,
                  main2 + zeros_for(res1) + res2)
  return joined1, joined2
# linearize through cond goes through cond's partial-evaluation rule.
_, f_lin = linearize(lambda x: cond(True, lambda: x, lambda: 0.), 1.)
out = f_lin(3.14)
print(out)
3.14
def cond_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
                   ) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Atom]]:
  """Partial-evaluation rule for a cond equation inside a jaxpr.

  Splits `eqn` into two equations:
  - `eqn1` (known) produces the known outputs plus residuals;
  - `eqn2` (unknown) consumes the predicate, those residuals and the unknown
    inputs.

  Returns `(eqn1, eqn2, unks_out, res)`. `res` holds the residual Vars, plus
  the predicate when it is a Var. Presumably these are what the caller must
  pass from the known side to the unknown side; confirm.
  """
  pred_unk, *unks_in = unks_in
  assert not pred_unk  # the predicate must be known
  true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
  *jaxprs, unks_out, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, unks_in)
  t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs
  ins1, ins2 = partition_list(unks_in, eqn.inputs[1:])
  outs1, outs2 = partition_list(unks_out, eqn.out_binders)
  # Reuse the unknown branch's leading (residual) binders as the Vars that
  # carry residuals from eqn1 to eqn2.
  residuals, _ = split_list(t_jaxpr2.in_binders, num_res)
  eqn1 = JaxprEqn(cond_p, [eqn.inputs[0], *ins1],
                  dict(true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1),
                  outs1 + residuals)
  eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],
                  dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
                  outs2)
  # A Var predicate is also needed by eqn2, so report it as a residual; a
  # literal predicate needs no plumbing.
  res = [eqn.inputs[0], *residuals] if type(eqn.inputs[0]) is Var else residuals
  return eqn1, eqn2, unks_out, res
partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn
# Same as above but under jit, so cond appears as an equation inside a jaxpr
# (presumably handled by cond_peval_eqn during jit's partial evaluation).
_, f_lin = linearize(jit(lambda x: cond(True, lambda: x, lambda: 0.)), 1.)
out = f_lin(3.14)
print(out)
3.14
转置是 transpose_jaxpr
的一个相当直接的应用。
def cond_transpose_rule(cts, pred, *invals, true_jaxpr, false_jaxpr):
  """Transpose cond by transposing each branch and binding a new cond.

  The predicate gets no cotangent. Each `UndefPrimal` input receives the
  next output of the transposed cond, and known inputs get `None`.
  """
  undef_mask = tuple(type(v) is UndefPrimal for v in invals)
  t_true, consts_true = transpose_jaxpr(true_jaxpr, undef_mask)
  t_false, consts_false = transpose_jaxpr(false_jaxpr, undef_mask)
  t_true, t_false = _join_jaxpr_consts(
      t_true, t_false, len(consts_true), len(consts_false))
  known_vals = [v for v in invals if type(v) is not UndefPrimal]
  results = iter(bind_cond(pred, *consts_true, *consts_false, *known_vals,
                           *cts, true_jaxpr=t_true, false_jaxpr=t_false))
  ct_invals = [next(results) if is_undef else None for is_undef in undef_mask]
  return [None, *ct_invals]
transpose_rules[cond_p] = cond_transpose_rule
# grad through cond (linearize + transpose): d/dx (x * x) at x = 1 is 2.
out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
print(out)
2.0
显示代码单元格源代码
def pprint_cond(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
  """Pretty-print a cond equation followed by its two indented branch jaxprs."""
  params = eqn.params
  branches = (params['true_jaxpr'], params['false_jaxpr'])
  shown_params = {k: v for k, v in params.items() if not k.endswith('jaxpr')}
  # `names` is a defaultdict, so lookups may create names; keep the original
  # order (outputs first, then inputs).
  out_names = ' '.join(var_str(names, v) for v in eqn.out_binders)
  in_names = ' '.join(names[x] if isinstance(x, Var) else str(x.val)
                      for x in eqn.inputs)
  lhs = pp(out_names)
  rhs = pp(eqn.primitive.name) >> pp_params(shown_params) >> pp(in_names)
  header = lhs >> pp(' = ') >> rhs
  return vcat([header] + [pp_jaxpr(branch).indent(2) for branch in branches])
pp_rules[cond_p] = pprint_cond