Autodidax:从头开始构建 JAX 核心#
你是否曾经想学习 JAX 的工作原理,但发现其实现难以理解?现在你有机会了!通过阅读本教程,你将了解 JAX 核心系统中的所有重要概念。你甚至会了解我们奇怪的术语!
这是一个正在进行中的草稿。还有一些重要的成分缺失,将在第五部分和第六部分(以及更多?)中陆续添加。这里也有一些我们尚未应用于主系统的简化,但我们将来会应用。
第一部分:变换作为解释器:标准评估、jvp
和 vmap
#
我们想要变换看起来像这样的函数
def f(x):
y = sin(x) * 2.
z = - y + x
return z
将像 sin
这样的函数以及中缀运算符(mul
、add
和 neg
)背后的算术运算视为基本运算,这意味着它们是处理的原子单元,而不是组合。
“转换”意味着“以不同的方式解释”。与我们对数值输入应用原始操作以生成数值输出的标准解释不同,我们希望覆盖原始应用,并让不同的值流过我们的程序。例如,我们可能希望将每个原语的应用替换为应用它的 JVP 规则,并让原始-切线对流过我们的程序。此外,我们希望能够组合多个转换,从而形成解释器堆栈。
JAX 核心机制#
我们可以实现解释器堆栈,甚至可以在执行要转换的 Python 函数时动态地释放它们。首先,让我们定义这些原语,以便我们可以拦截它们的应用。
from typing import NamedTuple
class Primitive(NamedTuple):
    """An atomic operation, identified only by its name.

    Instances are used as keys of the per-interpreter rule tables
    (``impl_rules``, ``jvp_rules``, ...).
    """

    name: str
# One Primitive instance per operation. The rule tables defined later in this
# file (impl_rules, jvp_rules, vmap_rules, ...) are keyed on these instances.
add_p = Primitive('add')
mul_p = Primitive('mul')
neg_p = Primitive("neg")
sin_p = Primitive("sin")
cos_p = Primitive("cos")
reduce_sum_p = Primitive("reduce_sum")
greater_p = Primitive("greater")
less_p = Primitive("less")
transpose_p = Primitive("transpose")
broadcast_p = Primitive("broadcast")
# User-facing wrappers. Array operands are passed to ``bind1`` positionally and
# metadata (``perm``, ``shape``, ``axes``) by keyword.

def add(x, y):
    """Bind ``add_p`` to ``x`` and ``y`` and return its single output."""
    return bind1(add_p, x, y)


def mul(x, y):
    """Bind ``mul_p`` to ``x`` and ``y`` and return its single output."""
    return bind1(mul_p, x, y)


def neg(x):
    """Bind ``neg_p`` to ``x`` and return its single output."""
    return bind1(neg_p, x)


def sin(x):
    """Bind ``sin_p`` to ``x`` and return its single output."""
    return bind1(sin_p, x)


def cos(x):
    """Bind ``cos_p`` to ``x`` and return its single output."""
    return bind1(cos_p, x)


def greater(x, y):
    """Bind ``greater_p`` to ``x`` and ``y`` and return its single output."""
    return bind1(greater_p, x, y)


def less(x, y):
    """Bind ``less_p`` to ``x`` and ``y`` and return its single output."""
    return bind1(less_p, x, y)


def transpose(x, perm):
    """Bind ``transpose_p`` to ``x`` with the ``perm`` parameter."""
    return bind1(transpose_p, x, perm=perm)


def broadcast(x, shape, axes):
    """Bind ``broadcast_p`` to ``x`` with ``shape`` and ``axes`` parameters."""
    return bind1(broadcast_p, x, shape=shape, axes=axes)
def reduce_sum(x, axis=None):
    """Sum ``x`` over ``axis`` by binding ``reduce_sum_p``.

    Args:
      x: the array-like value (or tracer) to reduce.
      axis: ``None`` to reduce over every axis, a single integer, or an
        iterable of integers. Negative entries count from the last axis.

    Returns:
      The single output of ``reduce_sum_p``.
    """
    ndim = np.ndim(x)
    if axis is None:
        axis = range(ndim)
    elif isinstance(axis, (int, np.integer)):
        axis = (axis,)
    # Canonicalize to a tuple of non-negative Python ints. The abstract-eval
    # and batching rules index shapes with these values directly, so a negative
    # axis passed through unchanged would give them a wrong answer.
    axis = tuple(int(ax) + ndim if ax < 0 else int(ax) for ax in axis)
    return bind1(reduce_sum_p, x, axis=axis)
def bind1(prim, *args, **params):
    """Bind ``prim`` and unwrap its result, which must hold exactly one output."""
    outs = bind(prim, *args, **params)
    (only_out,) = outs  # raises ValueError unless there is exactly one output
    return only_out
我们将在稍后设置数组数据类型和中缀运算符方法。
一个 Primitive
只是一个带有名称的对象,我们将在其中附加解释规则(每个转换一个)。bind
函数是我们的拦截点:它将根据参数在追踪器中的封装方式以及激活的解释器来确定要应用哪个转换规则。
用户代码调用的函数,例如 add
和 sin
,只是围绕 bind
调用的包装器。这些包装器使我们能够控制如何将参数传递给 bind
,特别是我们遵循一个方便的内部约定:当我们调用 bind
时,我们将表示数组数据的数值作为位置参数传递,并通过关键字传递元数据(例如 reduce_sum_p
的 axis
参数)。这种调用约定简化了一些核心逻辑(例如,稍后要定义的 Tracer
类的实例只能出现在 bind
的位置参数中)。包装器还可以提供文档字符串!
我们将活动解释器表示为一个堆栈。堆栈只是一个简单的 list
,每个元素都是一个容器,包含一个整数级别(对应于元素在堆栈中的高度)、一个解释器类型(我们称之为 trace_type
)和一个可选字段,用于解释器需要的任何全局数据。我们将每个元素称为 MainTrace
,尽管也许“Interpreter”更具描述性。
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any
class MainTrace(NamedTuple):
level: int
trace_type: type['Trace']
global_data: Any | None
# The interpreter stack. new_main gives each pushed entry a level equal to the
# stack length at push time.
trace_stack: list[MainTrace] = []
# When set, find_top_trace prefers this entry if its level exceeds that of
# every input tracer.
dynamic_trace: MainTrace | None = None # to be employed in Part 3
@contextmanager
def new_main(trace_type: type['Trace'], global_data=None):
    """Push a new ``MainTrace`` onto ``trace_stack`` for the ``with`` body.

    The new entry's level is the current stack height. It is yielded to the
    caller and popped again on exit, even if the body raises.
    """
    main = MainTrace(len(trace_stack), trace_type, global_data)
    trace_stack.append(main)
    try:
        yield main
    finally:
        trace_stack.pop()
当我们要应用转换时,我们将使用 new_main
将另一个解释器推入堆栈。然后,当我们在函数中应用原语时,我们可以认为 bind
首先被堆栈顶部的追踪器(即具有最高级别的追踪器)解释。如果第一个解释器本身在其对原语的解释规则中绑定了其他原语,就像 sin_p
的 JVP 规则可能绑定 cos_p
和 mul_p
一样,那么这些 bind
调用将由下一级解释器处理。
解释器堆栈的底部是什么?在底部,我们知道所有转换解释器都已完成,我们只想进行标准评估。因此,在底部我们将放置一个评估解释器。
让我们勾勒出解释器的接口,它基于 Trace
和 Tracer
基类。一个 Tracer
代表一个封装的值,可能带有解释器使用的某些额外的上下文数据。一个 Trace
处理将值封装到 Tracers
中,还处理原语应用。
class Trace:
    """Interpreter base class; its only state is a reference to its MainTrace."""

    main: MainTrace

    def __init__(self, main: MainTrace) -> None:
        self.main = main

    def pure(self, val):
        """Box a value that is not a Tracer. Subclasses must override."""
        assert False  # must override

    def lift(self, val):
        """Box a Tracer coming from a lower-level trace. Subclasses must override."""
        assert False  # must override

    def process_primitive(self, primitive, tracers, params):
        """Interpret one primitive application. Subclasses must override."""
        assert False  # must override
前两个方法是关于将值封装在 Tracer
中,这些对象是流经我们转换的 Python 程序的对象。最后一个方法是我们用来解释原语应用的回调。
Trace
本身不包含任何数据,除了对它相应的 MainTrace
实例的引用。事实上,在应用转换期间,可能会创建和丢弃多个 Trace
实例,而每个转换的应用只创建一个 MainTrace
实例。
至于 Tracer
本身,每一个都承载着一个抽象值(并将其中缀运算符转发给它),其余部分由转换决定。(Tracer
和 AbstractValue
之间的关系是,每个转换有一个 Tracer
,每个基本类型至少有一个 AbstractValue
,例如数组。)
import numpy as np
class Tracer:
    """Base class for boxed values that flow through a transformed program.

    Operators and unknown attributes are forwarded to the abstract value
    returned by ``aval``.
    """

    _trace: Trace

    __array_priority__ = 1000

    @property
    def aval(self):
        """The abstract value of this tracer. Subclasses must override."""
        assert False  # must override

    def full_lower(self):
        """Return an unboxed equivalent where possible; by default, ``self``."""
        return self  # default implementation

    # --- operators, all delegated to the abstract value ---

    def __neg__(self):
        return self.aval._neg(self)

    def __add__(self, other):
        return self.aval._add(self, other)

    def __radd__(self, other):
        return self.aval._radd(self, other)

    def __mul__(self, other):
        return self.aval._mul(self, other)

    def __rmul__(self, other):
        return self.aval._rmul(self, other)

    def __gt__(self, other):
        return self.aval._gt(self, other)

    def __lt__(self, other):
        return self.aval._lt(self, other)

    def __bool__(self):
        return self.aval._bool(self)

    def __nonzero__(self):
        return self.aval._nonzero(self)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; fall back to the aval.
        try:
            return getattr(self.aval, name)
        except AttributeError:
            raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")
def swap(f):
    """Return a two-argument function that calls ``f`` with its arguments exchanged."""
    def swapped(x, y):
        return f(y, x)
    return swapped
class ShapedArray:
    """Abstract value described by a shape and a dtype."""

    array_abstraction_level = 1
    shape: tuple[int, ...]
    dtype: np.dtype

    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype

    @property
    def ndim(self):
        """Number of axes, i.e. the length of ``shape``."""
        return len(self.shape)

    # Operator implementations that Tracer's dunder methods delegate to.
    _neg = staticmethod(neg)
    _add = staticmethod(add)
    _radd = staticmethod(swap(add))
    _mul = staticmethod(mul)
    _rmul = staticmethod(swap(mul))
    _gt = staticmethod(greater)
    _lt = staticmethod(less)

    @staticmethod
    def _bool(tracer):
        raise Exception("ShapedArray can't be unambiguously converted to bool")

    @staticmethod
    def _nonzero(tracer):
        raise Exception("ShapedArray can't be unambiguously converted to bool")

    def str_short(self):
        """Compact rendering such as ``float64[3,4]``."""
        dims = ",".join(str(d) for d in self.shape)
        return f'{self.dtype.name}[{dims}]'

    def __hash__(self):
        return hash((self.shape, self.dtype))

    def __eq__(self, other):
        if type(self) is not type(other):
            return False
        return self.shape == other.shape and self.dtype == other.dtype

    def __repr__(self):
        return f"ShapedArray(shape={self.shape}, dtype={self.dtype})"
class ConcreteArray(ShapedArray):
    """Abstract value wrapping one concrete array; shape and dtype come from it."""

    array_abstraction_level = 2
    val: np.ndarray

    def __init__(self, val):
        self.val = val
        self.shape = val.shape
        self.dtype = val.dtype

    @staticmethod
    def _bool(tracer):
        # The concrete value is known, so truthiness is well defined.
        return bool(tracer.aval.val)

    @staticmethod
    def _nonzero(tracer):
        return bool(tracer.aval.val)
def get_aval(x):
    """Return the abstract value of ``x``.

    Tracers report their own ``aval``. Values whose exact type is listed in
    ``jax_types`` are wrapped in a ``ConcreteArray``.

    Raises:
      TypeError: if ``x`` is neither a Tracer nor of a type in ``jax_types``.
    """
    if isinstance(x, Tracer):
        return x.aval
    if type(x) not in jax_types:
        raise TypeError(x)
    return ConcreteArray(np.asarray(x))
# Exact (non-Tracer) types accepted as array data. get_aval and full_raise
# test membership with type(x), so subclasses of these types are not accepted.
jax_types = {bool, int, float,
np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}
请注意,我们实际上有两个 AbstractValue
用于数组,代表不同的抽象级别。一个 ShapedArray
代表所有具有给定形状和类型的可能的数组集。一个 ConcreteArray
代表一个包含单个数组值的单例集。
现在我们已经设置了解释器堆栈、解释器的 Trace/Tracer API 和抽象值,我们可以回到实现 bind
def bind(prim, *args, **params):
    """Apply ``prim`` under the interpreter selected by ``find_top_trace``.

    Every argument is first boxed for that trace, the trace's
    ``process_primitive`` is invoked, and each output is passed through
    ``full_lower``. Returns a list of outputs.
    """
    trace = find_top_trace(args)
    boxed_args = [full_raise(trace, arg) for arg in args]
    results = trace.process_primitive(prim, boxed_args, params)
    return [full_lower(result) for result in results]
主要操作是,我们调用 find_top_trace
来确定哪个解释器应该处理此原语应用。然后我们调用顶级追踪器的 process_primitive
,以便追踪器可以应用其解释规则。对 full_raise
的调用只是确保输入被封装在顶级追踪器的 Tracer
实例中,对 full_lower
的调用是一个可选的优化,以便我们尽可能地将值从 Tracer
中解封。
import operator as op
def find_top_trace(xs) -> Trace:
    """Instantiate the interpreter that should handle a primitive applied to ``xs``.

    Picks the highest-level ``MainTrace`` among the Tracers in ``xs``, falling
    back to the bottom of ``trace_stack`` when there are none. ``dynamic_trace``
    takes over when it is set and has a strictly higher level.
    """
    mains_of_tracers = (x._trace.main for x in xs if isinstance(x, Tracer))
    chosen = max(mains_of_tracers, default=trace_stack[0],
                 key=op.attrgetter('level'))
    if dynamic_trace and dynamic_trace.level > chosen.level:
        chosen = dynamic_trace
    return chosen.trace_type(chosen)
换句话说,暂且忽略 dynamic_trace
这一步(留到第 3 部分再讨论),find_top_trace
返回与输入上的 Tracer
关联的最高级别解释器,否则返回堆栈底部的解释器(这始终是评估追踪器,至少目前如此)。这与我们之前描述的,我们总是从运行堆栈顶部的解释器开始,然后向下工作,应用堆栈中的每个解释器有所不同。相反,我们只在原语绑定的输入参数被封装在一个与该解释器相对应的 Tracer
中时才应用该解释器。这种优化使我们能够跳过不相关的转换,但代价是把一个假设固化进了系统:转换主要遵循数据依赖关系(特殊的堆栈底部解释器除外,它解释一切)。
另一个方法是让堆栈中的每个解释器解释每个操作。这值得探索!JAX 在很大程度上是围绕数据依赖关系设计的,因为这对自动微分来说非常自然,而 JAX 的根源在于自动微分。但它可能过度拟合。
def full_lower(val: Any):
    """Call ``val.full_lower()`` for Tracers; return any other value unchanged."""
    return val.full_lower() if isinstance(val, Tracer) else val
def full_raise(trace: Trace, val: Any) -> Tracer:
    """Box ``val`` as a Tracer of ``trace``.

    Non-Tracer values go through ``trace.pure``; Tracers from a lower level go
    through ``trace.lift``; Tracers already belonging to ``trace.main`` are
    returned unchanged.

    Raises:
      Exception: if ``val`` comes from a higher level, or from a different
        main trace at the same level.
    """
    if not isinstance(val, Tracer):
        assert type(val) in jax_types
        return trace.pure(val)

    val_main = val._trace.main
    if val_main is trace.main:
        return val

    level = trace.main.level
    if val_main.level < level:
        return trace.lift(val)
    if val_main.level > level:
        raise Exception(f"Can't lift level {val._trace.main.level} to {level}.")
    # Remaining case: val._trace.level == level but a different main trace.
    raise Exception(f"Different traces at same level: {val._trace}, {trace}.")
full_raise
中的逻辑用于将值封装到特定 Trace
的 Tracer
中,根据上下文在 Trace
上调用不同的方法:对非 Tracer
常量调用 Trace.pure
,而对已经是来自较低级别解释器的 Tracer
的值调用 Trace.lift
。这两种方法可以共享相同的实现,但是通过在核心逻辑中区分它们,我们可以为 Trace
子类提供更多信息。
这就是 JAX 核心的全部内容!现在我们可以开始添加解释器了。
评估解释器#
我们将从最简单的解释器开始:将位于解释器堆栈底部的评估解释器。
class EvalTrace(Trace):
    """Interpreter that evaluates primitives with their ``impl_rules`` entry."""

    def pure(self, x):
        return x  # no boxing in Tracers needed

    lift = pure

    def process_primitive(self, primitive, tracers, params):
        impl = impl_rules[primitive]
        return impl(*tracers, **params)
# Install the evaluation interpreter as the level-0 entry of the stack.
trace_stack.append(MainTrace(0, EvalTrace, None)) # special bottom of the stack
# NB: in JAX, instead of a dict we attach impl rules to the Primitive instance
# Each impl rule takes operands positionally and parameters by keyword, and
# returns a list of outputs (here always of length one).
impl_rules = {}
impl_rules[add_p] = lambda x, y: [np.add(x, y)]
impl_rules[mul_p] = lambda x, y: [np.multiply(x, y)]
impl_rules[neg_p] = lambda x: [np.negative(x)]
impl_rules[sin_p] = lambda x: [np.sin(x)]
impl_rules[cos_p] = lambda x: [np.cos(x)]
impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]
impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]
impl_rules[less_p] = lambda x, y: [np.less(x, y)]
impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]
def broadcast_impl(x, *, shape, axes):
    """Impl rule for ``broadcast_p``.

    Inserts a length-1 axis at each position in ``axes`` (in ascending order)
    and then broadcasts the result to ``shape``. Returns a one-element list.
    """
    expanded = x
    for position in sorted(axes):
        expanded = np.expand_dims(expanded, position)
    return [np.broadcast_to(expanded, shape)]
impl_rules[broadcast_p] = broadcast_impl
使用此解释器,我们可以评估用户函数
def f(x):
y = sin(x) * 2.
z = - y + x
return z
print(f(3.0))
2.7177599838802657
哇!就像绕着一个大圈子走。但这种间接的目的在于,现在我们可以添加一些真正的转换。
使用 jvp
的前向模式自动微分#
首先,几个辅助函数
import builtins
def zeros_like(val):
    """Return a NumPy array of zeros with the shape and dtype of ``val``'s aval."""
    abstract = get_aval(val)
    return np.zeros(abstract.shape, abstract.dtype)
def unzip2(pairs):
    """Split an iterable of 2-element items into two lists, preserving order."""
    firsts, seconds = [], []
    for first, second in pairs:
        firsts.append(first)
        seconds.append(second)
    return firsts, seconds
def map(f, *xs):
    """Eager variant of the builtin ``map`` that returns a list."""
    return [*builtins.map(f, *xs)]
def zip(*args):
    """Eager, length-checked variant of the builtin ``zip``.

    Every argument is materialized as a list first, so one-shot iterators are
    accepted. Called with no arguments it returns ``[]``, like
    ``list(builtins.zip())``.

    Returns:
      A list of tuples.

    Raises:
      AssertionError: if the arguments do not all have the same length.
    """
    # Materialize here rather than through the module-level ``map`` so this
    # helper does not depend on which ``map`` is in scope.
    seqs = [list(arg) for arg in args]
    if not seqs:
        return []
    expected = len(seqs[0])
    for seq in seqs[1:]:
        assert len(seq) == expected
    return list(builtins.zip(*seqs))
前向模式自动微分的 Tracer
承载着原始-切线对。 Trace
应用 JVP 规则。
class JVPTracer(Tracer):
    """Tracer that carries a primal value together with its tangent."""

    def __init__(self, trace, primal, tangent):
        self._trace = trace
        self.primal = primal
        self.tangent = tangent

    @property
    def aval(self):
        # The abstract value is that of the primal.
        return get_aval(self.primal)
class JVPTrace(Trace):
    """Interpreter that applies the JVP rule of each primitive."""

    def pure(self, val):
        # A freshly boxed value gets a zero tangent.
        return JVPTracer(self, val, zeros_like(val))

    lift = pure

    def process_primitive(self, primitive, tracers, params):
        primals_in = [t.primal for t in tracers]
        tangents_in = [t.tangent for t in tracers]
        rule = jvp_rules[primitive]
        primal_outs, tangent_outs = rule(primals_in, tangents_in, **params)
        return [JVPTracer(self, primal, tangent)
                for primal, tangent in zip(primal_outs, tangent_outs)]


# Table of JVP rules, keyed on Primitive.
jvp_rules = {}
请注意,pure
和 lift
都将值封装到一个具有最少上下文信息的 JVPTracer
中,即零切线值。
让我们为原语添加一些 JVP 规则
# Every rule takes (primals, tangents, **params) and returns a pair
# (list of primal outputs, list of tangent outputs).

def add_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    return [x + y], [x_dot + y_dot]


jvp_rules[add_p] = add_jvp


def mul_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    # Product rule.
    return [x * y], [x_dot * y + x * y_dot]


jvp_rules[mul_p] = mul_jvp


def sin_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    return [sin(x)], [cos(x) * x_dot]


jvp_rules[sin_p] = sin_jvp


def cos_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    return [cos(x)], [-sin(x) * x_dot]


jvp_rules[cos_p] = cos_jvp


def neg_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    return [neg(x)], [neg(x_dot)]


jvp_rules[neg_p] = neg_jvp


def reduce_sum_jvp(primals, tangents, *, axis):
    (x,) = primals
    (x_dot,) = tangents
    return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)]


jvp_rules[reduce_sum_p] = reduce_sum_jvp


def greater_jvp(primals, tangents):
    x, y = primals
    # The input tangents are ignored; the output tangent is all zeros.
    out_primal = greater(x, y)
    return [out_primal], [zeros_like(out_primal)]


jvp_rules[greater_p] = greater_jvp


def less_jvp(primals, tangents):
    x, y = primals
    # The input tangents are ignored; the output tangent is all zeros.
    out_primal = less(x, y)
    return [out_primal], [zeros_like(out_primal)]


jvp_rules[less_p] = less_jvp
最后,我们添加一个转换 API 来启动追踪器
def jvp_v1(f, primals, tangents):
    """Evaluate ``f(*primals)`` together with its JVP along ``tangents``.

    ``f`` must take arrays positionally and return a single array.
    Returns the pair ``(primal_out, tangent_out)``.
    """
    with new_main(JVPTrace) as main:
        trace = JVPTrace(main)
        boxed_inputs = [JVPTracer(trace, primal, tangent)
                        for primal, tangent in zip(primals, tangents)]
        result = f(*boxed_inputs)
        # The result may not depend on the inputs; box it so that .primal and
        # .tangent are available either way.
        boxed_result = full_raise(trace, result)
        primal_out = boxed_result.primal
        tangent_out = boxed_result.tangent
    return primal_out, tangent_out
有了这个,我们可以微分了!
x = 3.0
y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,))
print(sin_deriv_at_3)
print(cos(3.0))
-0.9899924966004454
-0.9899924966004454
def f(x):
y = sin(x) * 2.
z = - y + x
return z
x, xdot = 3., 1.
y, ydot = jvp_v1(f, (x,), (xdot,))
print(y)
print(ydot)
2.7177599838802657
2.979984993200891
def deriv(f):
return lambda x: jvp_v1(f, (x,), (1.,))[1]
print(deriv(sin)(3.))
print(deriv(deriv(sin))(3.))
print(deriv(deriv(deriv(sin)))(3.))
print(deriv(deriv(deriv(deriv(sin))))(3.))
-0.9899924966004454
-0.1411200080598672
0.9899924966004454
0.1411200080598672
def f(x):
if x > 0.: # Python control flow
return 2. * x
else:
return x
print(deriv(f)(3.))
print(deriv(f)(-3.))
2.0
1.0
Pytrees 和扁平化用户函数的输入和输出#
jvp_v1
的一个限制是,它假设用户函数接受数组作为位置参数,并生成单个数组作为输出。如果它生成列表作为输出呢?或者接受嵌套容器作为输入呢?在堆栈的每一层处理所有可能的输入和输出容器将是一件很痛苦的事情。相反,我们可以包装用户函数,以便包装后的版本接受数组作为输入,并返回一个扁平的数组列表作为输出。包装器只需要解封其输入,调用用户函数,并扁平化输出。
以下是如何编写 jvp
,假设用户始终提供接受数组作为输入并生成扁平的数组列表作为输出的函数
def jvp_flat(f, primals, tangents):
    """JVP of a function that takes flat array arguments and returns a flat list.

    Returns ``(primals_out, tangents_out)``, two lists of equal length.
    """
    with new_main(JVPTrace) as main:
        trace = JVPTrace(main)
        boxed_inputs = [JVPTracer(trace, primal, tangent)
                        for primal, tangent in zip(primals, tangents)]
        results = f(*boxed_inputs)
        boxed_results = [full_raise(trace, result) for result in results]
        primals_out = [t.primal for t in boxed_results]
        tangents_out = [t.tangent for t in boxed_results]
    return primals_out, tangents_out
为了支持在输入和输出中具有任意容器的用户函数,以下是如何编写面向用户的 jvp
包装器
def jvp(f, primals, tangents):
    """JVP of ``f`` for arbitrary pytree inputs and outputs.

    Args:
      f: function to differentiate; called as ``f(*primals)``.
      primals: container of input values.
      tangents: container of input tangents with the same pytree structure
        as ``primals``.

    Returns:
      ``(primals_out, tangents_out)``, both with the structure of ``f``'s output.

    Raises:
      TypeError: if ``primals`` and ``tangents`` differ in pytree structure.
    """
    primals_flat, in_tree = tree_flatten(primals)
    tangents_flat, in_tree2 = tree_flatten(tangents)
    if in_tree != in_tree2:
        raise TypeError(
            "jvp: primals and tangents must have the same pytree structure")
    f, out_tree = flatten_fun(f, in_tree)
    primals_out_flat, tangents_out_flat = jvp_flat(f, primals_flat, tangents_flat)
    # out_tree() is only populated once f has run inside jvp_flat.
    primals_out = tree_unflatten(out_tree(), primals_out_flat)
    tangents_out = tree_unflatten(out_tree(), tangents_out_flat)
    return primals_out, tangents_out
请注意,我们必须将用户函数输出的树结构回传给 flatten_fun
的调用者。该信息直到我们实际运行用户函数才可用,因此 flatten_fun
只返回对可变单元格的引用,表示为一个 thunk。这些副作用是安全的,因为我们总是恰好运行一次用户函数。(这种安全的机制是 linear_util.py
中“线性”名称的原因,从线性类型的角度来看。)
剩下的就是编写 tree_flatten
、tree_unflatten
和 flatten_fun
。
显示代码单元格来源
def flatten_fun(f, in_tree):
    """Wrap ``f`` so that it takes flat arguments and returns a flat list.

    Returns ``(flat_fun, store)``. Calling ``store()`` after ``flat_fun`` has
    run yields the tree structure of ``f``'s output.
    """
    out_tree_store = Store()

    def flat_fun(*args_flat):
        structured_args = tree_unflatten(in_tree, args_flat)
        result = f(*structured_args)
        result_flat, result_tree = tree_flatten(result)
        out_tree_store.set_value(result_tree)
        return result_flat

    return flat_fun, out_tree_store
class Empty:
    """Type of the sentinel that marks a ``Store`` as not yet written."""


empty = Empty()


class Store:
    """Write-once cell; calling the instance returns the stored value."""

    val = empty  # class-level default, shadowed per instance by set_value

    def set_value(self, val):
        # Writing a second time is a programming error.
        assert self.val is empty
        self.val = val

    def __call__(self):
        return self.val
显示代码单元格来源
from collections.abc import Hashable, Iterable, Iterator
import itertools as it
from collections.abc import Callable
class NodeType(NamedTuple):
    """Flatten/unflatten handlers registered for one container type."""

    name: str
    to_iterable: Callable    # container -> (metadata, children)
    from_iterable: Callable  # (metadata, children) -> container
def register_pytree_node(
    ty: type, to_iter: Callable, from_iter: Callable
) -> None:
    """Register ``ty`` as a pytree container in ``node_types``.

    The entry is named ``str(ty)`` and replaces any earlier entry for ``ty``.
    """
    entry = NodeType(str(ty), to_iter, from_iter)
    node_types[ty] = entry
# Registry mapping a container type to its flatten/unflatten handlers.
node_types: dict[type, NodeType] = {}
# to_iter returns (metadata, children); from_iter receives (metadata, children).
register_pytree_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs))
register_pytree_node(list, lambda l: (None, l), lambda _, xs: list(xs))
# Dicts are flattened in sorted-key order; the tuple of keys is the metadata.
# `map` and `zip` here are this file's list-returning versions.
register_pytree_node(dict,
lambda d: map(tuple, unzip2(sorted(d.items()))),
lambda keys, vals: dict(zip(keys, vals)))
class PyTreeDef(NamedTuple):
    """Structure of one container node: its handlers, metadata and child structures."""

    node_type: NodeType
    node_metadata: Hashable
    child_treedefs: tuple['PyTreeDef', ...]


class Leaf:
    """Type of the ``leaf`` marker used in place of a PyTreeDef for non-containers."""


leaf = Leaf()
def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]:
    """Flatten ``x`` into ``(list of leaves, tree structure)``."""
    leaves, treedef = _tree_flatten(x)
    return list(leaves), treedef


def _tree_flatten(x: Any) -> tuple[Iterable, PyTreeDef]:
    """Recursive worker for ``tree_flatten``; the leaves may be a lazy iterable."""
    node_type = node_types.get(type(x))
    if not node_type:
        # Unregistered types are leaves.
        return [x], leaf
    node_metadata, children = node_type.to_iterable(x)
    children_flat, child_trees = unzip2(map(_tree_flatten, children))
    all_leaves = it.chain.from_iterable(children_flat)
    return all_leaves, PyTreeDef(node_type, node_metadata, tuple(child_trees))
def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any:
    """Rebuild a container of structure ``treedef`` from the flat leaves ``xs``."""
    return _tree_unflatten(treedef, iter(xs))


def _tree_unflatten(treedef: PyTreeDef, xs: Iterator) -> Any:
    """Recursive worker; consumes leaves from the shared iterator ``xs`` in order."""
    if treedef is leaf:
        return next(xs)
    # Lazy on purpose: each child pulls its leaves from ``xs`` only when the
    # container constructor iterates over the children.
    rebuilt_children = (_tree_unflatten(child, xs)
                        for child in treedef.child_treedefs)
    return treedef.node_type.from_iterable(treedef.node_metadata,
                                           rebuilt_children)
使用这个基于 pytree 的 jvp
实现,我们现在可以处理任意的输入和输出容器。这在以后的转换中会很有用!
def f(x):
y = sin(x) * 2.
z = - y + x
return {'hi': z, 'there': [x, y]}
x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
{'hi': np.float64(2.7177599838802657), 'there': [3.0, np.float64(0.2822400161197344)]}
{'hi': np.float64(2.979984993200891), 'there': [1.0, np.float64(-1.9799849932008908)]}
使用 vmap
进行向量化批处理#
首先,我们提供两个辅助函数:一个用于从未映射的抽象值中生成映射抽象值(通过删除一个轴),另一个用于移动批处理维度。
def mapped_aval(batch_dim, aval):
    """Return a ``ShapedArray`` like ``aval`` but with axis ``batch_dim`` removed."""
    dims = list(aval.shape)
    dims.pop(batch_dim)
    return ShapedArray(tuple(dims), aval.dtype)
def move_batch_axis(axis_size, src, dst, x):
    """Return ``x`` with its batch axis at position ``dst``.

    When ``src`` is ``not_mapped``, a new axis of length ``axis_size`` is
    created at ``dst`` by broadcasting. Otherwise the existing axis ``src`` is
    moved to ``dst``.
    """
    if src is not_mapped:
        full_shape = list(np.shape(x))
        full_shape.insert(dst, axis_size)
        return broadcast(x, full_shape, [dst])
    if src == dst:
        return x
    return moveaxis(x, src, dst)
def moveaxis(x, src: int, dst: int):
    """Move axis ``src`` of ``x`` to position ``dst`` using ``transpose``."""
    # Start from all axes except ``src`` in their original order, then put
    # ``src`` back at the requested position.
    order = [axis for axis in range(np.ndim(x)) if axis != src]
    order.insert(dst, src)
    return transpose(x, order)
用于向量化批处理的 Tracer
包含一个批处理值和一个可选整数,该整数指示哪个轴(如果有)是批处理轴。
from typing import Union
# Sentinel type and instance marking a value that has no batch axis.
class NotMapped: pass
not_mapped = NotMapped()
# A batch axis is either the sentinel or an integer axis index.
BatchAxis = Union[NotMapped, int]
class BatchTracer(Tracer):
    """Tracer holding a value plus the position of its batch axis, if it has one."""

    def __init__(self, trace, val, batch_dim: BatchAxis):
        self._trace = trace
        self.val = val
        self.batch_dim = batch_dim

    @property
    def aval(self):
        full_aval = get_aval(self.val)
        if self.batch_dim is not_mapped:
            return full_aval
        # Hide the batch axis from the abstract value.
        return mapped_aval(self.batch_dim, full_aval)

    def full_lower(self):
        # An unbatched value does not need this box.
        if self.batch_dim is not_mapped:
            return full_lower(self.val)
        return self
class BatchTrace(Trace):
    """Interpreter that applies the batching rule of each primitive."""

    def pure(self, val):
        return BatchTracer(self, val, not_mapped)

    lift = pure

    def process_primitive(self, primitive, tracers, params):
        vals_in = [t.val for t in tracers]
        bdims_in = [t.batch_dim for t in tracers]
        rule = vmap_rules[primitive]
        val_outs, bdim_outs = rule(self.axis_size, vals_in, bdims_in, **params)
        return [BatchTracer(self, val, bdim)
                for val, bdim in zip(val_outs, bdim_outs)]

    @property
    def axis_size(self):
        """Batch-axis size, stored as the main trace's global data."""
        return self.main.global_data


# Table of batching rules, keyed on Primitive.
vmap_rules = {}
在这里,我们实现了可选的 Tracer.full_lower
方法,该方法允许我们剥离一个批处理跟踪器(如果不需要,因为它不代表批处理值)。
对于 BatchTrace
,与 JVPTrace
类似,方法 pure
和 lift
只是将一个值打包到带有最小上下文信息的 BatchTracer
中,在本例中是带有哨兵值 not_mapped
的 batch_dim
。请注意,我们使用 MainTrace
的解释器全局数据字段来存储批处理轴大小。
接下来,我们可以为每个基元定义批处理解释器规则。
from functools import partial
def binop_batching_rule(op, axis_size, vals_in, dims_in):
    """Batching rule shared by the binary primitives.

    Aligns the batch axes of both operands, then applies ``op``. The output's
    batch axis is that of the first operand after alignment.
    """
    x, y = vals_in
    x_bdim, y_bdim = dims_in
    if x_bdim != y_bdim:
        if x_bdim is not_mapped:
            # Only y is batched: give x a batch axis where y has one.
            x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
            x_bdim = y_bdim
        else:
            # Bring y's batch axis to where x has it.
            y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
    return [op(x, y)], [x_bdim]


vmap_rules[add_p] = partial(binop_batching_rule, add)
vmap_rules[mul_p] = partial(binop_batching_rule, mul)
def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):
    """Batching rule for unary primitives: apply ``op`` and keep the batch axis."""
    (operand,) = vals_in
    (operand_bdim,) = dims_in
    return [op(operand)], [operand_bdim]


vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)
vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)
vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)
def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):
    """Batching rule for ``reduce_sum_p``.

    Args:
      axis_size: size of the batch axis (not used by this rule).
      vals_in: one-element sequence holding the operand.
      dims_in: one-element sequence holding the operand's batch axis, or
        ``not_mapped``.
      axis: axes to reduce, numbered as the unbatched computation sees them.

    Returns:
      ``([out], [out_bdim])``.
    """
    (x,), (x_bdim,) = vals_in, dims_in
    if x_bdim is not_mapped:
        # No batch axis to step around. The integer arithmetic below would
        # raise TypeError on the sentinel.
        return [reduce_sum(x, axis)], [not_mapped]
    # Shift every reduced axis at or after the batch axis up by one.
    new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)
    # Each reduced axis in front of the batch axis moves the batch axis down.
    out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)
    return [reduce_sum(x, new_axis)], [out_bdim]


vmap_rules[reduce_sum_p] = reduce_sum_batching_rule
最后,我们添加一个转换 API 来启动追踪器
def vmap_flat(f, in_axes, *args):
    """Vectorize ``f``, which takes flat array arguments and returns a flat list.

    Args:
      f: function to vectorize.
      in_axes: one entry per argument: the integer batch axis, or
        ``not_mapped`` (``None`` is accepted as a synonym) for an argument
        that is not batched.
      *args: the arguments.

    Returns:
      A list of outputs, each with its batch axis at position 0.

    Raises:
      ValueError: if no argument is mapped, or the mapped axes differ in size.
    """
    # Use one spelling of "unmapped" so the two tests below cannot disagree.
    in_axes = [not_mapped if ax is None else ax for ax in in_axes]
    axis_sizes = {x.shape[ax] for x, ax in zip(args, in_axes)
                  if ax is not not_mapped}
    if not axis_sizes:
        raise ValueError("vmap requires at least one mapped argument")
    if len(axis_sizes) > 1:
        raise ValueError(
            f"vmap got inconsistent mapped axis sizes: {sorted(axis_sizes)}")
    axis_size, = axis_sizes
    with new_main(BatchTrace, axis_size) as main:
        trace = BatchTrace(main)
        # Unmapped arguments are passed through unboxed; they are boxed on
        # demand when combined with a batched value.
        tracers_in = [x if ax is not_mapped else BatchTracer(trace, x, ax)
                      for x, ax in zip(args, in_axes)]
        outs = f(*tracers_in)
        tracers_out = [full_raise(trace, out) for out in outs]
        vals_out, bdims_out = unzip2((t.val, t.batch_dim) for t in tracers_out)
    # This runs after the batching interpreter has been popped.
    outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out)
                       for val_out, bdim in zip(vals_out, bdims_out)]
    return outs_transposed
def vmap(f, in_axes):
    """Return a version of ``f`` that is vectorized over the axes in ``in_axes``.

    Args:
      f: function to vectorize; may take and return arbitrary pytrees.
      in_axes: pytree of batch axes with the same structure as the tuple of
        arguments the returned function will be called with.

    Returns:
      A function taking the same arguments as ``f``. It raises ``TypeError``
      if the arguments and ``in_axes`` differ in pytree structure.
    """
    def batched_f(*args):
        args_flat, in_tree = tree_flatten(args)
        in_axes_flat, in_tree2 = tree_flatten(in_axes)
        if in_tree != in_tree2:
            raise TypeError(
                "vmap: in_axes must have the same pytree structure as the arguments")
        f_flat, out_tree = flatten_fun(f, in_tree)
        outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat)
        return tree_unflatten(out_tree(), outs_flat)
    return batched_f
def add_one_to_a_scalar(scalar):
assert np.ndim(scalar) == 0
return 1 + scalar
vector_in = np.arange(3.)
vector_out = vmap(add_one_to_a_scalar, (0,))(vector_in)
print(vector_in)
print(vector_out)
[0. 1. 2.]
[1. 2. 3.]
def jacfwd(f, x):
    """Compute a Jacobian of ``f`` at ``x`` by vmapping ``jvp`` over basis vectors."""
    def pushfwd(v):
        _, tangent_out = jvp(f, (x,), (v,))
        return tangent_out

    # One standard-basis tangent per element of x, each shaped like x.
    basis = np.eye(np.size(x)).reshape(np.shape(x) * 2)
    return vmap(pushfwd, (0,))(basis)
def f(x):
return sin(x)
jacfwd(f, np.arange(3.))
array([[ 1. , 0. , -0. ],
[ 0. , 0.54030231, -0. ],
[ 0. , 0. , -0.41614684]])
这就是 jvp
和 vmap
的全部内容!
第 2 部分:Jaxprs#
即将出现的下一个转换是用于即时编译的 jit
和用于反向模式自动微分的 vjp
。(grad
只是围绕 vjp
的一个小型包装器。)而 jvp
和 vmap
只需要每个 Tracer
传递少量额外上下文信息,对于 jit
和 vjp
,我们需要更丰富的上下文:我们需要表示程序。也就是说,我们需要 jaxprs!
Jaxprs 是 JAX 的程序内部中间表示。它们是显式类型化的、函数式的、一阶的,并且处于 ANF 形式。我们需要一个程序表示来实现 jit
,因为 jit
的目的是将计算从 Python 中分离出来。对于任何我们想要分离的计算,我们需要能够将它表示为数据,并在跟踪 Python 函数时构建它。同样,vjp
需要一种方法来表示反向模式自动微分反向传递的计算。我们对这两种需求都使用相同的 jaxpr 程序表示。
(构建程序表示是最 自由 的跟踪转换类型,因此除了处理本机 Python 控制流方面的问题外,任何转换都可以通过首先跟踪到 jaxpr,然后解释 jaxpr 来实现。)
Jaxpr 数据结构#
jaxpr 的项(term)语法大致如下:
jaxpr ::=
{ lambda <binder> , ... .
let <eqn>
...
in ( <atom> , ... ) }
binder ::= <var>:<array_type>
var ::= a | b | c | ...
atom ::= <var> | <literal>
literal ::= <int32> | <int64> | <float32> | <float64>
eqn ::= <binder> , ... = <primitive> [ <params> ] <atom> , ...
类型的语法如下:
jaxpr_type ::= [ <array_type> , ... ] -> [ <array_type> , ... ]
array_type ::= <dtype>[<shape>]
dtype ::= f32 | f64 | i32 | i64
shape ::= <int> , ...
我们如何用 Python 数据结构来表示它们?我们重用 ShapedArrays 来表示类型,并且可以使用一些 Python 结构体来表示项(term)的语法。
class Var:
    """A jaxpr variable; carries only its type."""

    aval: ShapedArray

    def __init__(self, aval):
        self.aval = aval


class Lit:
    """A literal appearing inline in a jaxpr."""

    val: Any
    aval: ShapedArray

    def __init__(self, val):
        shaped = raise_to_shaped(get_aval(val))
        self.aval = shaped
        # Store the value as an array of the literal's dtype.
        self.val = np.array(val, shaped.dtype)
Atom = Union[Var, Lit]
class JaxprEqn(NamedTuple):
    """One equation: ``out_binders = primitive[params] inputs``."""

    primitive: Primitive
    inputs: list[Atom]
    params: dict[str, Any]
    out_binders: list[Var]


class Jaxpr(NamedTuple):
    """A program: input binders, a list of equations, and output atoms."""

    in_binders: list[Var]
    eqns: list[JaxprEqn]
    outs: list[Atom]

    # Hash and compare by identity rather than by tuple contents.
    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return self is other
def raise_to_shaped(aval):
    """Return a plain ``ShapedArray`` with the shape and dtype of ``aval``."""
    shape, dtype = aval.shape, aval.dtype
    return ShapedArray(shape, dtype)
对 jaxpr 进行类型检查涉及检查是否存在未绑定变量,变量是否只绑定一次,以及对于每个等式,基元应用的类型是否与输出绑定器的类型匹配。
class JaxprType(NamedTuple):
    """Type of a jaxpr: the types of its inputs and of its outputs."""

    in_types: list[ShapedArray]
    out_types: list[ShapedArray]

    def __repr__(self):
        def render(avals):
            return ', '.join(aval.str_short() for aval in avals)

        in_types = render(self.in_types)
        out_types = render(self.out_types)
        return f'({in_types}) -> ({out_types})'
def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType:
    """Type-check ``jaxpr`` and return its type.

    Raises:
      TypeError: if a variable is bound more than once, an unbound variable
        is used, or an equation's output binders disagree with the types
        computed by the primitive's abstract-eval rule.
    """
    bound: set[Var] = set()

    for binder in jaxpr.in_binders:
        if binder in bound:
            raise TypeError
        bound.add(binder)

    for eqn in jaxpr.eqns:
        operand_types = [typecheck_atom(bound, x) for x in eqn.inputs]
        rule = abstract_eval_rules[eqn.primitive]
        result_types = rule(*operand_types, **eqn.params)
        for out_binder, out_type in zip(eqn.out_binders, result_types):
            if not out_type == out_binder.aval:
                raise TypeError
        # Outputs enter scope only after the whole equation has been checked.
        for out_binder in eqn.out_binders:
            if out_binder in bound:
                raise TypeError
            bound.add(out_binder)

    in_types = [v.aval for v in jaxpr.in_binders]
    out_types = [typecheck_atom(bound, x) for x in jaxpr.outs]
    return JaxprType(in_types, out_types)
def typecheck_atom(env: set[Var], x: Atom) -> ShapedArray:
    """Return the type of atom ``x``; variables must already be in ``env``."""
    if isinstance(x, Var):
        if x not in env:
            raise TypeError("unbound variable")
        return x.aval
    if isinstance(x, Lit):
        return raise_to_shaped(get_aval(x.val))
    assert False
我们可以使用一个简单的解释器将 jaxpr 表示的函数应用于参数。
def eval_jaxpr(jaxpr: Jaxpr, args: list[Any]) -> list[Any]:
    """Evaluate ``jaxpr`` on ``args`` by binding each equation's primitive."""
    env: dict[Var, Any] = {}

    def read(x: Atom) -> Any:
        if type(x) is Var:
            return env[x]
        return x.val

    def write(v: Var, val: Any) -> None:
        assert v not in env  # single-assignment
        env[v] = val

    # NOTE: `map` here is this file's eager, list-returning version, so the
    # calls below run `write` immediately.
    map(write, jaxpr.in_binders, args)
    for eqn in jaxpr.eqns:
        operands = map(read, eqn.inputs)
        results = bind(eqn.primitive, *operands, **eqn.params)
        map(write, eqn.out_binders, results)
    return map(read, jaxpr.outs)
def jaxpr_as_fun(jaxpr: Jaxpr):
    """Return a function that evaluates ``jaxpr`` on its positional arguments."""
    def run(*args):
        return eval_jaxpr(jaxpr, args)
    return run
通过在解释器中使用 bind
,这个解释器本身是可以跟踪的。
使用跟踪构建 jaxprs#
现在我们已经将 jaxprs 作为一种数据结构,我们需要方法从跟踪 Python 代码中生成它们。通常,我们有两种跟踪到 jaxpr 的变体;jit
使用一种,而 vjp
使用另一种。我们将从 jit
使用的一种开始,该方法也用于控制流基元,如 lax.cond
、lax.while_loop
和 lax.scan
。
def split_list(lst: list[Any], n: int) -> tuple[list[Any], list[Any]]:
    """Split ``lst`` into its first ``n`` elements and the remainder."""
    assert 0 <= n <= len(lst)
    head = lst[:n]
    tail = lst[n:]
    return head, tail
def partition_list(bs: list[bool], l: list[Any]) -> tuple[list[Any], list[Any]]:
    """Partition ``l`` by the flags in ``bs``.

    Returns ``(items flagged False, items flagged True)``, each in original order.
    """
    assert len(bs) == len(l)
    when_false, when_true = [], []
    buckets = (when_false, when_true)
    for flag, item in zip(bs, l):
        buckets[flag].append(item)  # False indexes slot 0, True slot 1
    return when_false, when_true
# NB: the analogous class in JAX is called 'DynamicJaxprTracer'
class JaxprTracer(Tracer):
    """Tracer used while building a jaxpr; it carries only an abstract value."""

    __slots__ = ['aval']

    aval: ShapedArray

    def __init__(self, trace, aval):
        self._trace = trace
        self.aval = aval
# NB: the analogous class in JAX is called 'DynamicJaxprTrace'
class JaxprTrace(Trace):
    """Interpreter that records primitive applications as jaxpr equations."""

    @property
    def builder(self):
        """The ``JaxprBuilder`` stored as the main trace's global data."""
        return self.main.global_data

    def new_arg(self, aval: ShapedArray) -> JaxprTracer:
        """Create a tracer for a jaxpr input and bind it to a fresh ``Var``."""
        shaped = raise_to_shaped(aval)
        tracer = self.builder.new_tracer(self, shaped)
        self.builder.tracer_to_var[id(tracer)] = Var(shaped)
        return tracer

    def get_or_make_const_tracer(self, val: Any) -> JaxprTracer:
        """Return the tracer registered for constant ``val``, creating it if needed.

        Constants are looked up by ``id(val)``.
        """
        tracer = self.builder.const_tracers.get(id(val))
        if tracer is not None:
            return tracer
        tracer = self.builder.new_tracer(self, raise_to_shaped(get_aval(val)))
        self.builder.add_const(tracer, val)
        return tracer

    pure = lift = get_or_make_const_tracer

    def process_primitive(self, primitive, tracers, params):
        builder = self.builder
        avals_in = [t.aval for t in tracers]
        avals_out = abstract_eval_rules[primitive](*avals_in, **params)
        out_tracers = [builder.new_tracer(self, aval) for aval in avals_out]
        inputs = [builder.getvar(t) for t in tracers]
        outvars = [builder.add_var(t) for t in out_tracers]
        builder.add_eqn(JaxprEqn(primitive, inputs, params, outvars))
        return out_tracers
# NB: in JAX, we instead attach abstract eval rules to Primitive instances
abstract_eval_rules = {}
请注意,我们将一个 builder 对象保留为解释器全局数据,该对象在构建 jaxpr 时会跟踪变量、常量和等式。
class JaxprBuilder:
    """Accumulates the variables, constants and equations of a jaxpr under construction."""

    eqns: list[JaxprEqn]
    tracer_to_var: dict[int, Var]            # keyed on id(tracer)
    const_tracers: dict[int, JaxprTracer]    # keyed on id(constant value)
    constvals: dict[Var, Any]
    # NOTE(review): presumably kept so that tracers stay alive and the
    # id()-based keys above stay unique; confirm.
    tracers: list[JaxprTracer]

    def __init__(self):
        self.eqns = []
        self.tracer_to_var = {}
        self.const_tracers = {}
        self.constvals = {}
        self.tracers = []

    def new_tracer(self, trace: JaxprTrace, aval: ShapedArray) -> JaxprTracer:
        """Create a tracer for ``trace`` and record it in ``self.tracers``."""
        fresh = JaxprTracer(trace, aval)
        self.tracers.append(fresh)
        return fresh

    def add_eqn(self, eqn: JaxprEqn) -> None:
        self.eqns.append(eqn)

    def add_var(self, tracer: JaxprTracer) -> Var:
        """Bind ``tracer`` to a fresh ``Var``; it must not be bound already."""
        key = id(tracer)
        assert key not in self.tracer_to_var
        var = Var(tracer.aval)
        self.tracer_to_var[key] = var
        return var

    def getvar(self, tracer: JaxprTracer) -> Var:
        """Return the ``Var`` already bound to ``tracer``."""
        var = self.tracer_to_var.get(id(tracer))
        assert var is not None
        return var

    def add_const(self, tracer: JaxprTracer, val: Any) -> Var:
        """Register ``val`` as a constant represented by ``tracer``."""
        var = self.add_var(tracer)
        self.const_tracers[id(val)] = tracer
        self.constvals[var] = val
        return var

    def build(self, in_tracers: list[JaxprTracer], out_tracers: list[JaxprTracer]
              ) -> tuple[Jaxpr, list[Any]]:
        """Assemble, type-check and return ``(jaxpr, constant values)``.

        Constant binders come first among the jaxpr inputs; scalar constants
        are then inlined as literals by ``_inline_literals``.
        """
        constvars, constvals = unzip2(self.constvals.items())
        var_of = self.tracer_to_var
        in_binders = constvars + [var_of[id(t)] for t in in_tracers]
        out_vars = [var_of[id(t)] for t in out_tracers]
        jaxpr = Jaxpr(in_binders, self.eqns, out_vars)
        typecheck_jaxpr(jaxpr)
        jaxpr, constvals = _inline_literals(jaxpr, constvals)
        return jaxpr, constvals
def _inline_literals(jaxpr: Jaxpr, consts: list[Any]) -> tuple[Jaxpr, list[Any]]:
    """Replace scalar constant inputs of ``jaxpr`` with inline literals.

    The first ``len(consts)`` input binders are taken to be the constant
    binders. Returns the rewritten (and re-checked) jaxpr and the constants
    that remain as inputs.
    """
    const_binders, other_binders = split_list(jaxpr.in_binders, len(consts))
    # A constant is inlined when its type is in jax_types and its shape is ().
    is_scalar = [type(x) in jax_types and not get_aval(x).shape for x in consts]
    kept_binders, lit_binders = partition_list(is_scalar, const_binders)
    kept_consts, lit_vals = partition_list(is_scalar, consts)
    literals = {binder: Lit(val) for binder, val in zip(lit_binders, lit_vals)}

    def substitute(atom):
        return literals.get(atom, atom)

    new_eqns = [
        JaxprEqn(eqn.primitive, [substitute(x) for x in eqn.inputs],
                 eqn.params, eqn.out_binders)
        for eqn in jaxpr.eqns
    ]
    new_outs = [substitute(x) for x in jaxpr.outs]
    new_jaxpr = Jaxpr(kept_binders + other_binders, new_eqns, new_outs)
    typecheck_jaxpr(new_jaxpr)
    return new_jaxpr, kept_consts
我们需要为 JaxprTrace.process_primitive
制定的规则本质上是基元应用的类型规则:给定基元、其参数以及输入的类型,该规则必须生成输出的类型,然后将其与输出 JaxprTracer
打包在一起。我们可以将抽象评估规则用于相同目的,即使它们可能更通用(因为抽象评估规则必须接受 ConcreteArray 输入,并且因为它们只需要返回一组可能输出的上限,所以它们也可以生成 ConcreteArray 输出)。我们将这些抽象评估规则重复用于其他 jaxpr 生成跟踪机制,在那里额外的通用性很有用。
# Each rule maps input avals (plus params by keyword) to a list of output avals.

def binop_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
    """Both operands must have the same shape and dtype; the output matches them."""
    if not (isinstance(x, ShapedArray) and isinstance(y, ShapedArray)):
        raise TypeError
    if raise_to_shaped(x) != raise_to_shaped(y):
        raise TypeError
    return [ShapedArray(x.shape, x.dtype)]


abstract_eval_rules[add_p] = binop_abstract_eval
abstract_eval_rules[mul_p] = binop_abstract_eval


def compare_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
    """Operands must have the same shape; the output is boolean of that shape."""
    if not (isinstance(x, ShapedArray) and isinstance(y, ShapedArray)):
        raise TypeError
    if x.shape != y.shape:
        raise TypeError
    return [ShapedArray(x.shape, np.dtype('bool'))]


abstract_eval_rules[greater_p] = compare_abstract_eval
abstract_eval_rules[less_p] = compare_abstract_eval


def vectorized_unop_abstract_eval(x: ShapedArray) -> list[ShapedArray]:
    """The output has the operand's shape and dtype."""
    return [ShapedArray(x.shape, x.dtype)]


abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval
abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval
abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval


def reduce_sum_abstract_eval(x: ShapedArray, *, axis: tuple[int, ...]
                             ) -> list[ShapedArray]:
    """Drop the dimensions whose index appears in ``axis``."""
    reduced = set(axis)
    kept = tuple(d for i, d in enumerate(x.shape) if i not in reduced)
    return [ShapedArray(kept, x.dtype)]


abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval


def broadcast_abstract_eval(x: ShapedArray, *, shape: Sequence[int],
                            axes: Sequence[int]) -> list[ShapedArray]:
    """The output has the requested ``shape`` and the operand's dtype."""
    return [ShapedArray(tuple(shape), x.dtype)]


abstract_eval_rules[broadcast_p] = broadcast_abstract_eval
为了检查我们对 jaxprs 的实现,我们可以添加一个 make_jaxpr
转换和一个美化打印器。
from functools import lru_cache
@lru_cache  # ShapedArrays are hashable
def make_jaxpr_v1(f, *avals_in):
    """Trace ``f`` on inputs of types ``avals_in``.

    Returns ``(jaxpr, constants, output tree structure)``.
    """
    flat_avals, in_tree = tree_flatten(avals_in)
    flat_f, out_tree = flatten_fun(f, in_tree)

    builder = JaxprBuilder()
    with new_main(JaxprTrace, builder) as main:
        trace = JaxprTrace(main)
        tracers_in = [trace.new_arg(aval) for aval in flat_avals]
        results = flat_f(*tracers_in)
        tracers_out = [full_raise(trace, result) for result in results]
        jaxpr, consts = builder.build(tracers_in, tracers_out)
    return jaxpr, consts, out_tree()
显示代码单元格来源
from collections import defaultdict
import string
class PPrint:
    """A block of text stored as ``(indent, text)`` lines."""

    lines: list[tuple[int, str]]

    def __init__(self, lines):
        self.lines = lines

    def indent(self, indent: int) -> 'PPrint':
        """Return a copy with every line indented by ``indent`` more columns."""
        shifted = [(indent + old_indent, text) for old_indent, text in self.lines]
        return PPrint(shifted)

    def __add__(self, rhs: 'PPrint') -> 'PPrint':
        """Vertical concatenation."""
        return PPrint(self.lines + rhs.lines)

    def __rshift__(self, rhs: 'PPrint') -> 'PPrint':
        """Horizontal concatenation: ``rhs`` starts at the end of our last line.

        The remaining lines of ``rhs`` are indented to line up under that point.
        """
        if not rhs.lines:
            return self
        if not self.lines:
            return rhs
        head = self.lines[:-1]
        last_indent, last_text = self.lines[-1]
        first_indent, first_text = rhs.lines[0]
        joined = (last_indent, last_text + ' ' * first_indent + first_text)
        tail = rhs.indent(last_indent + len(last_text)).lines[1:]
        return PPrint(head + [joined] + tail)

    def __str__(self) -> str:
        rendered = (' ' * indent + text for indent, text in self.lines)
        return '\n'.join(rendered)
def pp(s: Any) -> PPrint:
    """Wrap ``str(s)`` as a ``PPrint`` with one unindented entry per line."""
    text_lines = str(s).splitlines()
    return PPrint([(0, text) for text in text_lines])


def vcat(ps: list[PPrint]) -> PPrint:
    """Stack the blocks in ``ps`` vertically."""
    nothing = pp('')  # a block with no lines
    return sum(ps, nothing)
def pp_jaxpr(jaxpr: Jaxpr) -> PPrint:
    """Pretty-print ``jaxpr``; variables are named lazily in order of first use."""
    # Yields 'a', 'b', ..., 'z', then two-letter names, and so on.
    namegen = (''.join(letters)
               for length in it.count(1)
               for letters in it.permutations(string.ascii_lowercase, length))
    names = defaultdict(lambda: next(namegen))

    # Names are handed out on first lookup, so the order of the next three
    # statements determines which variable gets which letter.
    in_binders = ', '.join(var_str(names, x) for x in jaxpr.in_binders)
    eqns = vcat([pp_eqn(names, e) for e in jaxpr.eqns])
    outs = ', '.join(names[v] if isinstance(v, Var) else str(v.val)
                     for v in jaxpr.outs)

    header = pp(f'{{ lambda {in_binders} .')
    body = (pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')
    return header + body.indent(2)
def var_str(names: defaultdict[Var, str], v: Var) -> str:
    """Render ``v`` as ``name:type``."""
    name = names[v]
    type_str = v.aval.str_short()
    return f'{name}:{type_str}'
def pp_eqn(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
    """Pretty-print one equation, using a custom ``pp_rules`` entry if there is one."""
    rule = pp_rules.get(eqn.primitive)
    if rule:
        return rule(names, eqn)

    # Render the left-hand side first so output binders are named before any
    # not-yet-named inputs.
    lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
    operands = ' '.join(names[x] if isinstance(x, Var) else str(x.val)
                        for x in eqn.inputs)
    rhs = pp(eqn.primitive.name) >> pp_params(eqn.params) >> pp(operands)
    return lhs >> pp(' = ') >> rhs
def pp_params(params: dict[str, Any]) -> PPrint:
    """Render equation parameters as a bracketed `key=value` block sorted by key.

    With no parameters, a single space is returned as a separator.
    """
    items = sorted(params.items())
    if not items:
        return pp(' ')
    entries = [pp(f'{k}={v}') for k, v in items]
    return pp(' [ ') >> vcat(entries) >> pp(' ] ')
# Make `repr(jaxpr)` (and therefore `print(jaxpr)`) use the pretty-printer.
Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))
# Primitive-specific pretty-printing rules, consulted first by `pp_eqn`.
pp_rules: dict[Primitive, Callable[..., PPrint]] = {}
# Demo: trace a function that doubles its input, then print the jaxpr and its type.
jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. * x, raise_to_shaped(get_aval(3.)))
print(jaxpr)
print(typecheck_jaxpr(jaxpr))
{ lambda a:float64[] .
let b:float64[] = mul 2.0 a
in ( b ) }
(float64[]) -> (float64[])
但这里有一个限制:由于 find_top_trace
的操作方式是通过数据依赖关系,因此 make_jaxpr_v1
无法分离由其给定的 Python 可调用对象执行的所有基元操作。例如:
# Demo of the limitation: with no traced inputs the `mul` is not recorded
# (the captured output below shows an empty `let` and the literal 4.0).
jaxpr, consts, _ = make_jaxpr_v1(lambda: mul(2., 2.))
print(jaxpr)
{ lambda .
let
in ( 4.0 ) }
这正是 全阶段化 所解决的问题。我们希望确保由 make_jaxpr
启动的 JaxprTrace
始终被应用,无论 bind
的任何输入是否都打包在相应的 JaxprTracer
实例中。我们可以通过使用第 1 部分中定义的 dynamic_trace
全局变量来实现这一点。
@contextmanager
def new_dynamic(main: MainTrace):
    """Temporarily install `main` as the global `dynamic_trace`.

    The previous value is restored on exit, including when the body raises.
    """
    global dynamic_trace
    saved = dynamic_trace
    dynamic_trace = main
    try:
        yield
    finally:
        dynamic_trace = saved
@lru_cache
def make_jaxpr(f: Callable, *avals_in: ShapedArray,
               ) -> tuple[Jaxpr, list[Any], PyTreeDef]:
    """Trace `f` at the given abstract inputs and return `(jaxpr, consts, out_tree)`.

    The JaxprTrace's main is also installed as the dynamic trace while `f`
    runs. Results are memoized by `lru_cache` on `(f, *avals_in)`.
    """
    flat_avals, in_tree = tree_flatten(avals_in)
    flat_f, out_tree = flatten_fun(f, in_tree)

    builder = JaxprBuilder()
    with new_main(JaxprTrace, builder) as main, new_dynamic(main):
        trace = JaxprTrace(main)
        tracers_in = [trace.new_arg(aval) for aval in flat_avals]
        outs = flat_f(*tracers_in)
        tracers_out = [full_raise(trace, out) for out in outs]
        jaxpr, consts = builder.build(tracers_in, tracers_out)
    return jaxpr, consts, out_tree()
# Demo: with the dynamic trace in place, the input-free `mul` is now recorded.
jaxpr, consts, _ = make_jaxpr(lambda: mul(2., 2.))
print(jaxpr)
{ lambda .
let a:float64[] = mul 2.0 2.0
in ( a ) }
以这种方式使用 dynamic_trace
在概念上等同于隐藏当前解释器堆栈,并在底部使用 JaxprTrace
启动一个新的堆栈。也就是说,低于 dynamic_trace
的任何解释器都不会被应用(因为 JaxprTrace.process_primitive
不会调用 bind
),尽管如果被跟踪到 jaxpr 的 Python 可调用对象本身使用转换,那么这些转换可以被推送到解释器堆栈上,位于 JaxprTrace
之上。但是,暂时隐藏解释器堆栈会破坏系统状态。dynamic_trace
标签实现了相同目标,同时使系统状态更简单。
这就是 jaxprs 的全部内容!有了 jaxprs,我们就可以实现 JAX 中剩余的主要功能。
第 3 部分:jit
,简化#
虽然 jit
在 API 中类似于转换(因为它接受 Python 可调用对象作为参数),但在幕后,它实际上是一个高阶基元,而不是转换。当基元由函数参数化时,它就是高阶的。
动态(“最终风格”)和分阶段(“初始风格”)处理#
处理高阶基元有两种选择。每种选择都需要不同的跟踪方法,并产生不同的权衡。
动态处理,其中
bind
接受一个 Python 可调用对象作为参数。我们将形成 jaxpr 的时间推迟到尽可能晚,即直到我们在解释器堆栈的底部运行最终解释器。这样,我们就可以在解释器堆栈的底部交换一个JaxprTrace
,从而将所有基元操作阶段化输出(stage out),而不是执行它们。使用这种方法,堆栈中的转换会像往常一样在我们执行 Python 可调用对象时被应用。这种方法可能非常难以实现,但它是尽可能通用的,因为它允许高阶基元不提高其参数的抽象级别,从而允许数据依赖的 Python 控制流。我们将这种方法称为使用“最终风格高阶基元”,它采用我们迄今为止一直使用的、在跟踪时即完成处理的“最终风格转换”。
分阶段处理,其中
bind
接受一个 jaxpr 作为参数。在我们调用bind
之前,在基元包装器中,我们可以直接使用make_jaxpr
来提前形成一个 jaxpr,此后就完全不再需要这个 Python 可调用对象。在这种情况下,make_jaxpr
将其JaxprTrace
放置在解释器堆栈的顶部,并且在我们跟踪该 Python 可调用对象时,堆栈中位置更低的转换(它们可能通过闭包捕获的 Tracer 进入)都不会被应用到它上面。(在 Python 可调用对象内部应用的转换会照常应用,并被添加到 JaxprTrace 上方的堆栈中。)相反,较低的转换将在稍后应用于调用基元,而调用基元的规则必须随后对 jaxpr 本身进行转换。因为我们提前跟踪得到 jaxpr,所以这种方法不支持数据依赖的 Python 控制流,但它更容易实现。我们将这种高阶基元称为“初始风格高阶基元”,并将处理 jaxpr 的转换规则称为“初始风格转换规则”。
后一种方法适合 jit
,因为我们不需要在用户提供的 Python 可调用对象中支持数据依赖的 Python 控制流,因为 jit
的主要目的是将计算从 Python 中分离出来,以便由 XLA 执行。(相反,custom_jvp
是一个高阶基元,我们希望在其中支持数据依赖的 Python 控制流。)
从历史上看,我们在阅读完 typed tagless final interpreters 论文后,开始使用“初始风格”和“最终风格”术语,并开玩笑地将 JAX 称为“无类型标记最终解释器”的实现。我们并不声称要将这些术语的深层含义传承下来(或理解);我们松散地使用“初始风格”来表示“构建 AST 然后对其进行转换”,并使用“最终风格”来表示“在追踪时进行转换”。但这仅仅是不精确但又很流行的术语。
对于初始风格的方法,以下是面向用户的 jit
包装器
def jit(f):
    """Wrap `f` so each call traces it to a jaxpr and binds `xla_call` on that jaxpr."""
    def f_jitted(*args):
        # Abstract each concrete argument to its shaped aval for tracing.
        avals_in = []
        for arg in args:
            avals_in.append(raise_to_shaped(get_aval(arg)))
        jaxpr, consts, out_tree = make_jaxpr(f, *avals_in)
        # Constants found during tracing are passed ahead of the real arguments.
        flat_outs = bind(xla_call_p, *consts, *args,
                         jaxpr=jaxpr, num_consts=len(consts))
        return tree_unflatten(out_tree, flat_outs)
    return f_jitted


xla_call_p = Primitive('xla_call')
对于任何新的基本运算,都需要为其提供转换规则,从其评估规则开始。当我们评估 xla_call
基本运算的应用时,我们希望将计算阶段化到 XLA 中。这包括将 jaxpr 转换为 XLA HLO 程序,将参数值传输到 XLA 设备,执行 XLA 程序以及将结果传输回来。我们将缓存 XLA HLO 编译,以便对于每个 jit
过的函数,只需要针对每个参数形状和数据类型签名执行一次。
首先,一些实用程序。
class IDHashable:
    """Wrapper that hashes and compares by the identity of the wrapped object."""
    val: Any

    def __init__(self, val):
        self.val = val

    def __hash__(self) -> int:
        return id(self.val)

    def __eq__(self, other):
        # Only another IDHashable wrapping the very same object compares equal.
        if type(other) is not IDHashable:
            return False
        return id(other.val) == id(self.val)
接下来,我们将定义 xla_call
的评估规则
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
xe = xc._xla
xops = xc._xla.ops
def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
    """Evaluation rule for `xla_call`: obtain an executable for `jaxpr` and run it.

    The first `num_consts` positional values are constants; the rest are the
    runtime arguments passed to the executable.
    """
    consts = args[:num_consts]
    operands = args[num_consts:]
    # Wrap by identity so the jaxpr and constants can key `xla_callable`'s cache.
    hashable_consts = tuple(IDHashable(cnst) for cnst in consts)
    execute = xla_callable(IDHashable(jaxpr), hashable_consts)
    return execute(*operands)


impl_rules[xla_call_p] = xla_call_impl
@lru_cache
def xla_callable(hashable_jaxpr: IDHashable,
                 hashable_consts: tuple[IDHashable, ...]):
    """Compile a jaxpr (with its constants baked in) and return a function that runs it.

    Memoized by `lru_cache` on the identity-hashed jaxpr and constants.
    """
    jaxpr: Jaxpr = hashable_jaxpr.val
    typecheck_jaxpr(jaxpr)
    consts = [wrapped.val for wrapped in hashable_consts]
    # The leading in_binders correspond to constants; only the rest are parameters.
    param_binders = jaxpr.in_binders[len(consts):]
    in_avals = [v.aval for v in param_binders]

    c = xc.XlaBuilder('xla_call')
    xla_consts = _xla_consts(c, consts)
    xla_params = _xla_params(c, in_avals)
    outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params)
    out = xops.Tuple(c, outs)
    computation = c.build(out)
    module = xc._xla.mlir.xla_computation_to_mlir_module(computation)
    compiled = xb.get_backend(None).compile(module)
    out_avals = [v.aval for v in jaxpr.outs]
    return partial(execute_compiled, compiled, out_avals)
def _xla_consts(c: xe.XlaBuilder, consts: list[Any]) -> list[xe.XlaOp]:
    """Create one constant op per distinct constant object, returned in `consts` order.

    Constants are deduplicated by identity, so a repeated object shares one op.
    """
    op_by_id = {}
    for cnst in consts:
        if id(cnst) not in op_by_id:
            op_by_id[id(cnst)] = xops.ConstantLiteral(c, cnst)
    return [op_by_id[id(cnst)] for cnst in consts]
def _xla_params(c: xe.XlaBuilder, avals_in: list[ShapedArray]) -> list[xe.XlaOp]:
    """Declare one builder parameter per input aval, numbered by position."""
    params = []
    for index, aval in enumerate(avals_in):
        params.append(xops.Parameter(c, index, _xla_shape(aval)))
    return params
def _xla_shape(aval: ShapedArray) -> xe.Shape:
    """Build the XLA array shape matching the dtype and shape of `aval`."""
    element_type = xc.dtype_to_etype(aval.dtype)
    return xc.Shape.array_shape(element_type, aval.shape)
主要操作在 xla_callable
中,它使用 jaxpr_subcomp
将 jaxpr 编译为 XLA HLO 程序,然后返回一个可调用对象,该对象执行已编译的程序
def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: list[xe.XlaOp]
                  ) -> list[xe.XlaOp]:
    """Interpret `jaxpr` equation by equation, emitting XLA ops into builder `c`.

    `args` supplies the ops for `jaxpr.in_binders`; the ops for `jaxpr.outs`
    are returned.
    """
    env: dict[Var, xe.XlaOp] = {}

    def read(x: Atom) -> xe.XlaOp:
        # Variables come from the environment; anything else becomes an XLA constant.
        if type(x) is Var:
            return env[x]
        return xops.Constant(c, np.asarray(x.val))

    def write(v: Var, val: xe.XlaOp) -> None:
        env[v] = val

    # NOTE(review): `map` is called for its side effects here, which presumably
    # relies on an eager `map` defined elsewhere in this file -- confirm.
    map(write, jaxpr.in_binders, args)
    for eqn in jaxpr.eqns:
        in_avals = [x.aval for x in eqn.inputs]
        in_vals = map(read, eqn.inputs)
        translate = xla_translations[eqn.primitive]
        out_vals = translate(c, in_avals, in_vals, **eqn.params)
        map(write, eqn.out_binders, out_vals)
    return map(read, jaxpr.outs)
def execute_compiled(compiled, out_avals, *args):
    """Run a compiled executable on `args` and convert each output via `handle_result`."""
    input_bufs = []
    for arg in args:
        # Input handlers are looked up by the exact type of each argument.
        to_buffer = input_handlers[type(arg)]
        input_bufs.append(to_buffer(arg))
    out_bufs = compiled.execute(input_bufs)
    results = []
    for aval, buf in zip(out_avals, out_bufs):
        results.append(handle_result(aval, buf))
    return results
# Conversion used for plain Python/NumPy inputs: the backend's `buffer_from_pyval`.
default_input_handler = xb.get_backend(None).buffer_from_pyval
# Maps an argument's exact type to the function that turns it into an input buffer.
input_handlers = {ty: default_input_handler for ty in
[bool, int, float, np.ndarray, np.float64, np.float32]}
def handle_result(aval: ShapedArray, buf):
    """Convert an output buffer to a NumPy array; `aval` is accepted but not used yet."""
    del aval  # Unused for now
    result = np.asarray(buf)
    return result
# Per-primitive translation rules, looked up by `jaxpr_subcomp`.
xla_translations = {}
请注意,jaxpr_subcomp
具有一个简单的解释器的结构。这是一个常见的模式:我们处理 jaxpr 的方式通常是使用解释器。与任何解释器一样,我们需要为每个基本运算提供解释规则
def direct_translation(op, c, in_avals, in_vals):
    """Translate a primitive by applying `op` directly to the input values.

    The builder `c` and `in_avals` are part of the rule signature but unused.
    Returns a one-element list holding the result.
    """
    del c, in_avals
    result = op(*in_vals)
    return [result]
# Elementwise primitives map one-to-one onto the corresponding XLA ops.
xla_translations[add_p] = partial(direct_translation, xops.Add)
xla_translations[mul_p] = partial(direct_translation, xops.Mul)
xla_translations[neg_p] = partial(direct_translation, xops.Neg)
xla_translations[sin_p] = partial(direct_translation, xops.Sin)
xla_translations[cos_p] = partial(direct_translation, xops.Cos)
xla_translations[greater_p] = partial(direct_translation, xops.Gt)
xla_translations[less_p] = partial(direct_translation, xops.Lt)
def reduce_sum_translation(c, in_avals, in_vals, *, axis):
    """Translate `reduce_sum` to an XLA Reduce using a scalar add sub-computation."""
    (x_aval,) = in_avals
    (x,) = in_vals
    dtype = x_aval.dtype
    zero = xops.ConstantLiteral(c, np.array(0, dtype))
    # Build the two-parameter scalar adder that serves as the reduction body.
    subc = xc.XlaBuilder('add')
    scalar_shape = _xla_shape(ShapedArray((), dtype))
    lhs = xops.Parameter(subc, 0, scalar_shape)
    rhs = xops.Parameter(subc, 1, scalar_shape)
    xops.Add(lhs, rhs)
    return [xops.Reduce(c, [x], [zero], subc.build(), axis)]


xla_translations[reduce_sum_p] = reduce_sum_translation
def broadcast_translation(c, in_avals, in_vals, *, shape, axes):
    """Translate `broadcast` to XLA BroadcastInDim.

    The operand's dimensions are mapped to the output dimensions that are
    not listed in `axes`.
    """
    (x,) = in_vals
    dims_complement = []
    for dim in range(len(shape)):
        if dim not in axes:
            dims_complement.append(dim)
    return [xops.BroadcastInDim(x, shape, dims_complement)]


xla_translations[broadcast_p] = broadcast_translation
有了这些,我们现在可以使用 jit
来阶段化、编译和执行使用 XLA 的程序!
@jit
def f(x, y):
    # The print runs only while the function is being traced.
    print('tracing!')
    sin_x = sin(x)
    cos_y = cos(y)
    return sin_x * cos_y


z = f(3., 4.)  # 'tracing!' prints the first time
print(z)
tracing!
-0.09224219304455371
# Second call with the same argument types: no retracing, so nothing is printed.
z = f(4., 5.) # 'tracing!' doesn't print, compilation cache hit!
print(z)
-0.21467624978306993
@jit
def f(x):
    # Sum over the leading axis under jit.
    total = reduce_sum(x, axis=0)
    return total


print(f(np.array([1., 2., 3.])))
6.0
def f(x):
    doubled_sin = sin(x) * 2.
    return - doubled_sin + x


def deriv(f):
    """Return a function giving the tangent output of `jvp(f, ...)` for input tangent 1.0."""
    def df(x):
        return jvp(f, (x,), (1.,))[1]
    return df


# The second derivative evaluated directly, then with `jit` wrapped around it.
print(deriv(deriv(f))(3.))
print(jit(deriv(deriv(f)))(3.))
0.2822400161197344
0.2822400161197344
与其实现 jit
以首先追踪到 jaxpr,然后将 jaxpr 降低到 XLA HLO,看起来我们可以跳过 jaxpr 步骤,并在追踪时直接降低到 HLO。也就是说,也许我们可以使用 Trace
和 Tracer
实现 jit
,它们在每个基本运算绑定时增量地附加到 XLA HLO 图。目前这是正确的,但在我们引入编译的 SPMD 计算时将不再可能,因为在编译程序之前,我们必须知道所需的副本数量。
我们还没有为 xla_call_p
定义任何转换规则,除了它的评估规则之外。也就是说,我们还不能执行 vmap
-of-jit
或 jvp
-of-jit
,甚至 jit
-of-jit
。相反,jit
必须位于“顶层”。让我们解决这个问题!
def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
    """JVP rule for `xla_call`: bind `xla_call` on the JVP-transformed jaxpr."""
    del num_consts  # Unused
    new_jaxpr, new_consts = jvp_jaxpr(jaxpr)
    outs = bind(xla_call_p, *new_consts, *primals, *tangents,
                jaxpr=new_jaxpr, num_consts=len(new_consts))
    # The first half of the outputs are primals, the second half tangents.
    half = len(outs) // 2
    primals_out = outs[:half]
    tangents_out = outs[half:]
    return primals_out, tangents_out


jvp_rules[xla_call_p] = xla_call_jvp_rule
@lru_cache
def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
    """Trace the JVP of `jaxpr` into a new jaxpr taking (primals..., tangents...).

    Memoized by `lru_cache` on the jaxpr.
    """
    def jvp_traceable(*primals_and_tangents):
        half = len(primals_and_tangents) // 2
        primals = primals_and_tangents[:half]
        tangents = primals_and_tangents[half:]
        return jvp(jaxpr_as_fun(jaxpr), primals, tangents)

    # Tangent inputs are traced with the same avals as the primal inputs.
    in_avals = [v.aval for v in jaxpr.in_binders]
    new_jaxpr, new_consts, _ = make_jaxpr(jvp_traceable, *in_avals, *in_avals)
    return new_jaxpr, new_consts
def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
    """Batching rule for `xla_call`: bind `xla_call` on the vmapped jaxpr.

    Every output is reported as batched along axis 0.
    """
    del num_consts  # Unused
    bdims = tuple(dims_in)  # Hashable, so it can serve as a cache key.
    new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, bdims)
    outs = bind(xla_call_p, *new_consts, *vals_in,
                jaxpr=new_jaxpr, num_consts=len(new_consts))
    out_dims = [0 for _ in outs]
    return outs, out_dims


vmap_rules[xla_call_p] = xla_call_vmap_rule
@lru_cache
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
               ) -> tuple[Jaxpr, list[Any]]:
    """Trace the vmap of `jaxpr` into a new jaxpr over inputs with the batch axis added.

    Memoized by `lru_cache` on `(jaxpr, axis_size, bdims_in)`.
    """
    vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
    in_avals = []
    for binder, bdim in zip(jaxpr.in_binders, bdims_in):
        in_avals.append(unmapped_aval(axis_size, bdim, binder.aval))
    new_jaxpr, new_consts, _ = make_jaxpr(vmap_traceable, *in_avals)
    return new_jaxpr, new_consts
def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
                  ) -> ShapedArray:
    """Return `aval` with an axis of size `axis_size` inserted at `batch_dim`.

    If `batch_dim` is `not_mapped`, `aval` is returned unchanged.
    """
    if batch_dim is not_mapped:
        return aval
    dims = list(aval.shape)
    batched_dims = dims[:batch_dim] + [axis_size] + dims[batch_dim:]
    return ShapedArray(tuple(batched_dims), aval.dtype)
def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):
    """Abstract evaluation for `xla_call`: check input types and return the jaxpr's output types.

    Raises:
        TypeError: if an input type differs from the jaxpr's declared input type.
    """
    del num_consts  # Unused
    jaxpr_type = typecheck_jaxpr(jaxpr)
    type_pairs = zip(jaxpr_type.in_types, in_types)
    types_match = all(expected == actual for expected, actual in type_pairs)
    if not types_match:
        raise TypeError
    return jaxpr_type.out_types


abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule
def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts):
    """Translate a nested `xla_call` into a Call of a separately built sub-computation."""
    del num_consts  # Only used at top-level.
    # Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead.
    sub_builder = xc.XlaBuilder('inner xla_call')
    sub_params = _xla_params(sub_builder, in_avals)
    sub_outs = jaxpr_subcomp(sub_builder, jaxpr, sub_params)
    sub_computation = sub_builder.build(xops.Tuple(sub_builder, sub_outs))
    call_result = xops.Call(c, sub_computation, in_vals)
    return destructure_tuple(c, call_result)


xla_translations[xla_call_p] = xla_call_translation
def destructure_tuple(c, tup):
    """Return a list with one GetTupleElement op per element of the tuple op `tup`."""
    num_elements = len(c.get_shape(tup).tuple_shapes())
    elements = []
    for index in range(num_elements):
        elements.append(xops.GetTupleElement(tup, index))
    return elements
@jit
def f(x):
    # The print runs only while the function is being traced.
    print('tracing!')
    doubled_sin = sin(x) * 2.
    return - doubled_sin + x


x, xdot = 3., 1.
# jvp of a jitted function.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
tracing!
2.7177599838802657
2.979984993200891
# Repeating the same jvp does not retrace; then vmap the jitted function over axis 0.
y, ydot = jvp(f, (x,), (xdot,)) # 'tracing!' not printed
ys = vmap(f, (0,))(np.arange(3.))
print(ys)
[ 0. -0.68294197 0.18140515]
缺少的一块是数组的设备内存持久性。也就是说,我们已经定义了 handle_result
将结果传输回 CPU 内存作为 NumPy 数组,但通常最好避免传输结果只是为了在下一次操作时将它们传输回来。我们可以通过引入一个 Array
类来做到这一点,该类可以包装 XLA 缓冲区,并在其他情况下模拟 numpy.ndarray
def handle_result(aval: ShapedArray, buf):  # noqa: F811
    """Wrap an output buffer in an `Array` rather than converting it to NumPy."""
    wrapped = Array(aval, buf)
    return wrapped
class Array:
    """Array value that wraps a buffer together with its abstract value.

    `aval` supplies `dtype`, `shape` and `ndim`; `buf` is the wrapped buffer,
    converted to a NumPy array only on demand.
    """
    buf: Any
    aval: ShapedArray

    def __init__(self, aval, buf):
        self.aval = aval
        self.buf = buf

    dtype = property(lambda self: self.aval.dtype)
    shape = property(lambda self: self.aval.shape)
    ndim = property(lambda self: self.aval.ndim)

    def __array__(self, dtype=None, copy=None):
        """NumPy conversion hook.

        Accepts the optional `dtype` and `copy` arguments NumPy may pass, so
        calls such as `np.asarray(arr, dtype=...)` work instead of raising
        TypeError. With no arguments this is exactly `np.asarray(self.buf)`.
        """
        if copy:
            return np.array(self.buf, dtype=dtype, copy=True)
        return np.asarray(self.buf, dtype=dtype)

    def __repr__(self):
        return repr(np.asarray(self.buf))

    def __str__(self):
        return str(np.asarray(self.buf))

    # Operator implementations, exposed under the same private names as in the original.
    _neg = staticmethod(neg)
    _add = staticmethod(add)
    _radd = staticmethod(add)
    _mul = staticmethod(mul)
    _rmul = staticmethod(mul)
    _gt = staticmethod(greater)
    _lt = staticmethod(less)


# An Array already holds a buffer, so it is passed to the executable as-is.
input_handlers[Array] = lambda x: x.buf

jax_types.add(Array)
@jit
def f(x):
    doubled_sin = sin(x) * 2.
    return - doubled_sin + x


x, xdot = 3., 1.
# Same jvp-of-jit computation as before, now with results wrapped as `Array`.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
2.7177599838802657
2.979984993200891
显示代码单元格来源
def pprint_xla_call(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
    """Pretty-print an `xla_call` equation, followed by its inner jaxpr indented by 2.

    The `jaxpr` parameter is left out of the bracketed parameter list and
    printed as a nested block instead.
    """
    binders = [var_str(names, v) for v in eqn.out_binders]
    lhs = pp(' '.join(binders))
    params_without_jaxpr = {}
    for key, value in eqn.params.items():
        if key != 'jaxpr':
            params_without_jaxpr[key] = value
    prim_doc = pp(eqn.primitive.name)
    params_doc = pp_params(params_without_jaxpr)
    operands = [names[x] if isinstance(x, Var) else str(x.val)
                for x in eqn.inputs]
    rhs = prim_doc >> params_doc >> pp(' '.join(operands))
    equation_line = lhs >> pp(' = ') >> rhs
    inner = pp_jaxpr(eqn.params['jaxpr']).indent(2)
    return vcat([equation_line, inner])


pp_rules[xla_call_p] = pprint_xla_call
第 4 部分:linearize
和 vjp
(以及 grad
!)#
linearize
和 vjp
自动微分函数构建在 jvp
上,但也涉及 jaxpr。这是因为两者都涉及到阶段化,或延迟,计算。
linearize
#
在 linearize
的情况下,我们希望阶段化 jvp
计算的线性部分。也就是说,根据 类似 Haskell 的类型签名,如果我们有 jvp : (a -> b) -> (a, T a) -> (b, T b)
,那么我们写 linearize : (a -> b) -> a -> (b, T a -o T b)
,使用 T a
来表示“a
的切线类型”,并使用“棒棒糖”-o
而不是箭头 ->
来表示一个线性函数。我们也根据 jvp
来定义 linearize
的语义
y, f_lin = linearize(f, x)
y_dot = f_lin(x_dot)
对于 (y, y_dot)
给出相同的结果,与
y, y_dot = jvp(f, (x,), (x_dot,))
相同,其中 f_lin
的应用不会重做任何线性化工作。我们将表示延迟的线性部分 f_lin : T a -o T b
作为 jaxpr。
顺便说一下,现在我们有了线性箭头 -o
,我们可以为 jvp
提供更具信息性的类型
jvp : (a -> b) -> (UnrestrictedUse a, T a) -o (UnrestrictedUse b, T b)
这里我们写 UnrestrictedUse
只是为了表明我们有一个特殊的对,其中第一个元素可以在不受限制的(非线性)方式下使用。结合线性箭头,这个表示法只是为了表达函数 jvp f
以非线性方式使用其第一个输入,但以线性方式使用其第二个输入,产生相应的非线性输出(可以在非线性方式下使用)与线性输出配对。这个更精细的类型签名编码了 jvp f
中的数据依赖关系,这些依赖关系对于部分评估非常有用。
为了从 JVP 构建 f_lin
jaxpr,我们需要执行部分评估:我们在追踪时评估所有原始(primal)值,但将切线计算阶段化为 jaxpr。这是我们构建 jaxpr 的第二种方式。但是,make_jaxpr
及其底层的 JaxprTrace
/JaxprTracer
解释器旨在阶段化每个基本运算绑定,而这种第二种方法只阶段化那些与切线输入具有数据依赖关系的基本运算绑定。
首先,一些实用程序
def split_half(lst: list[Any]) -> tuple[list[Any], list[Any]]:
    """Split an even-length list into its first and second halves."""
    length = len(lst)
    assert not length % 2
    return split_list(lst, length // 2)
def merge_lists(which: list[bool], l1: list[Any], l2: list[Any]) -> list[Any]:
    """Interleave `l1` and `l2`: each True in `which` takes from `l2`, each False from `l1`.

    Asserts that both lists are fully consumed.
    """
    from_l1, from_l2 = iter(l1), iter(l2)
    merged = []
    for take_second in which:
        merged.append(next(from_l2) if take_second else next(from_l1))
    # Neither list may have elements left over.
    leftover1 = next(from_l1, None)
    leftover2 = next(from_l2, None)
    assert leftover1 is None and leftover2 is None
    return merged
接下来,我们将通过将 jvp
与一个一般的部分评估转换结合起来来编写 linearize
,该转换将在后面添加
def linearize_flat(f, *primals_in):
    """Linearize `f` at `primals_in` for flat inputs and outputs.

    Partially evaluates the JVP of `f` with primals known and tangents
    unknown. Returns the primal outputs and a function that evaluates the
    resulting jaxpr on tangent inputs.
    """
    known_primals = [PartialVal.known(x) for x in primals_in]
    unknown_tangents = [PartialVal.unknown(vspace(get_aval(x)))
                        for x in primals_in]
    pvals_in = known_primals + unknown_tangents

    def f_jvp(*primals_tangents_in):
        primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
        return [*primals_out, *tangents_out]

    jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)
    primal_pvals, _ = split_half(pvals_out)
    assert all(pval.is_known for pval in primal_pvals)
    primals_out = [pval.const for pval in primal_pvals]

    def f_lin(*tangents):
        return eval_jaxpr(jaxpr, [*consts, *tangents])

    return primals_out, f_lin
def linearize(f, *primals_in):
    """Pytree-aware wrapper around `linearize_flat`.

    Returns the primal outputs and a function `f_lin` mapping input tangents
    to output tangents.
    """
    primals_in_flat, in_tree = tree_flatten(primals_in)
    flat_f, out_tree = flatten_fun(f, in_tree)
    primals_out_flat, f_lin_flat = linearize_flat(flat_f, *primals_in_flat)
    primals_out = tree_unflatten(out_tree(), primals_out_flat)

    def f_lin(*tangents_in):
        tangents_in_flat, tangent_tree = tree_flatten(tangents_in)
        # The tangents must have the same tree structure as the primals.
        if in_tree != tangent_tree:
            raise TypeError
        tangents_out_flat = f_lin_flat(*tangents_in_flat)
        return tree_unflatten(out_tree(), tangents_out_flat)

    return primals_out, f_lin
def vspace(aval: ShapedArray) -> ShapedArray:
    """Return the aval used for tangents of values with abstract value `aval`."""
    tangent_aval = raise_to_shaped(aval)  # TODO handle integers?
    return tangent_aval
现在我们转向一般的部分评估转换。目标是接受一个 Python 可调用对象和一个输入列表,其中一些已知,另一些未知,并生成 (1) 可以从已知输入计算的所有输出,以及 (2) 表示 Python 可调用对象计算部分的 jaxpr,该部分只有在剩余输入已知后才能执行。
这种转换很难用类型签名来概括。如果我们假设输入函数的类型签名是 (a1, a2) -> (b1, b2)
,其中 a1
和 a2
分别代表已知和未知的输入,以及 b1
仅对 a1
具有数据依赖关系,而 b2
对 a2
具有某种数据依赖关系,那么我们可以写
partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> exists r. (b1, r, (r, a2) -> b2)
换句话说,给定类型为 a1
的输入的值,partial_eval
生成类型为 b1
的输出以及存在量化的类型为 r
的“残留”值,这些值代表在第二阶段完成计算所需的中间值。它还生成一个类型为 (r, a2) -> b2
的函数,该函数接受残留值以及剩余的输入并生成剩余的输出。
我们喜欢将部分评估视为将一个计算“解压缩”成两个。例如,考虑这个 jaxpr
{ lambda a:float64[] .
let b:float64[] = sin a
c:float64[] = neg b
in ( c ) }
JVP 的 jaxpr 看起来像这样
{ lambda a:float64[] b:float64[] .
let c:float64[] = sin a
d:float64[] = cos a
e:float64[] = mul d b
f:float64[] = neg c
g:float64[] = neg e
in ( f, g ) }
如果我们想象将部分评估应用于这个 jaxpr,其中第一个输入已知,而第二个未知,我们最终会将 JVP jaxpr “解压缩”成原始(primal)jaxpr 和切线(tangent)jaxpr
{ lambda a:float64[] .
let c:float64[] = sin a
d:float64[] = cos a
f:float64[] = neg c
in ( f, d ) }
{ lambda d:float64[] b:float64[] .
let e:float64[] = mul d b
g:float64[] = neg e
in ( g ) }
这个第二个 jaxpr 代表了我们想要从 linearize
中获得的线性计算。
然而,与这个 jaxpr 示例不同,我们希望对已知值的计算在评估输入 Python 可调用对象时发生。也就是说,与其为整个函数 (a1, a2) -> (b1, b2)
形成 jaxpr,首先将所有操作阶段化到 Python 之外,然后再确定哪些可以立即评估,哪些必须延迟,我们只希望为那些由于依赖于未知输入而必须延迟的操作形成 jaxpr。在自动微分的背景下,这是最终使我们能够处理像 grad(lambda x: x**2 if x > 0 else 0.)
这样的函数的特性。Python 控制流之所以起作用,是因为部分评估将基本计算保留在 Python 中。因此,我们的 Trace
和 Tracer
子类必须动态地确定哪些可以评估,哪些必须阶段化为 jaxpr。
首先,我们从一个 PartialVal
类开始,它代表一个可以是已知或未知的值。
class PartialVal(NamedTuple):
    """A value that is either known (has a `const`) or unknown (`const` is None)."""
    aval: ShapedArray
    const: Any | None

    @classmethod
    def known(cls, val: Any):
        """Build a known value; its aval comes from `get_aval(val)`."""
        return PartialVal(get_aval(val), val)

    @classmethod
    def unknown(cls, aval: ShapedArray):
        """Build an unknown value described only by `aval`."""
        return PartialVal(aval, None)

    @property
    def is_known(self):
        return self.const is not None

    @property
    def is_unknown(self):
        return self.const is None
部分评估将接收一个表示输入的 PartialVal
列表,并返回一个表示输出的 PartialVal
列表以及一个代表延迟计算的 jaxpr。
def partial_eval_flat(f: Callable, pvals_in: list[PartialVal]
                      ) -> tuple[Jaxpr, list[PartialVal], list[Any]]:
    """Partially evaluate `f` on inputs that are each known or unknown.

    Returns a jaxpr built from the unknown input and output tracers, the
    PartialVal of every output, and the jaxpr's constants.
    """
    with new_main(PartialEvalTrace) as main:
        trace = PartialEvalTrace(main)
        tracers_in = [trace.new_arg(pval) for pval in pvals_in]
        outs = f(*tracers_in)
        tracers_out = [full_raise(trace, out) for out in outs]
        pvals_out = [tracer.pval for tracer in tracers_out]
        # Only tracers whose values are unknown take part in the jaxpr.
        unk_tracers_in = [t for t in tracers_in if t.pval.is_unknown]
        unk_tracers_out = [t for t in tracers_out if t.pval.is_unknown]
        jaxpr, consts = tracers_to_jaxpr(unk_tracers_in, unk_tracers_out)
    return jaxpr, pvals_out, consts
接下来我们需要实现 PartialEvalTrace
及其 PartialEvalTracer
。这个解释器将在跟踪数据依赖关系的同时动态构建 jaxpr。为此,它构建了一个在 PartialEvalTracer
节点(代表分阶段输出的值)和 JaxprRecipe
节点(代表如何从其他值计算某些值的公式)之间进行的二部有向无环图(DAG)。一种食谱是 JaxprEqnRecipe
,对应于 JaxprEqn
的原始应用,但我们也有用于常量和 lambda 绑定器的食谱类型。
from weakref import ref, ReferenceType
class LambdaBindingRecipe(NamedTuple):
    # Field-less marker recipe; `PartialEvalTrace.new_arg` attaches it to argument tracers.
    pass
class ConstRecipe(NamedTuple):
    # Recipe for a tracer that stands for a constant; `val` is that constant.
    val: Any
class JaxprEqnRecipe(NamedTuple):
    # Recipe describing one primitive application to be turned into an equation.
    prim: Primitive                       # the primitive being applied
    tracers_in: list['PartialEvalTracer']  # input tracers of the application
    params: dict[str, Any]                # the primitive's parameters
    avals_out: list[ShapedArray]          # abstract values of the outputs
    tracer_refs_out: list['ReferenceType[PartialEvalTracer]']  # weak refs to output tracers


# Any of the three recipe kinds a PartialEvalTracer can carry.
JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
class PartialEvalTracer(Tracer):
    """Tracer carrying a PartialVal and, optionally, the recipe that produces it."""
    pval: PartialVal
    recipe: JaxprRecipe | None

    def __init__(self, trace, pval, recipe):
        self._trace = trace
        self.pval = pval
        self.recipe = recipe

    @property
    def aval(self):
        return self.pval.aval

    def full_lower(self):
        # A known value can be unwrapped to its constant; an unknown one stays a tracer.
        if not self.pval.is_known:
            return self
        return full_lower(self.pval.const)
该 PartialEvalTrace
包含用于构建 JaxprRecipe
和 PartialEvalTracer
图表的逻辑。每个参数对应于一个 LambdaBindingRecipe
叶节点,每个常量都是一个 ConstRecipe
叶节点,其中包含对该常量的引用。所有其他跟踪器和食谱都来自 process_primitive
,它使用 JaxprEqnRecipe
形成跟踪器。
对于大多数原语(primitive),process_primitive
逻辑很简单:如果所有输入都已知,那么我们可以将该原语绑定到已知的值上(在 Python 中对其求值),而不必为输出创建跟踪器。反之,如果有任何输入未知,那么我们改为将其阶段化输出为一个 JaxprEqnRecipe
,该 JaxprEqnRecipe
代表该原语的一次应用。要构建代表未知输出的跟踪器,我们需要 aval,这些 aval 来自抽象评估规则。(请注意,跟踪器引用 JaxprEqnRecipe
,而 JaxprEqnRecipe
引用跟踪器;我们使用弱引用来避免产生循环引用的垃圾对象。)
该 process_primitive
逻辑适用于大多数原语,但 xla_call_p
需要递归处理。因此,我们在 partial_eval_rules
字典中对它的规则进行了特殊处理。
class PartialEvalTrace(Trace):
    """Trace that evaluates primitives on known inputs and records recipes otherwise."""

    def new_arg(self, pval: PartialVal) -> Any:
        """Create an argument tracer for `pval`, tagged with a LambdaBindingRecipe."""
        return PartialEvalTracer(self, pval, LambdaBindingRecipe())

    def lift(self, val: Any) -> PartialEvalTracer:
        """Wrap a concrete value as a known tracer with no recipe."""
        known = PartialVal.known(val)
        return PartialEvalTracer(self, known, None)

    pure = lift

    def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:
        """Return an unknown tracer for `tracer`; a known one is recorded as a constant."""
        if tracer.pval.is_unknown:
            return tracer
        unknown = PartialVal.unknown(raise_to_shaped(tracer.aval))
        return PartialEvalTracer(self, unknown, ConstRecipe(tracer.pval.const))

    def process_primitive(self, primitive, tracers, params):
        # All inputs known: evaluate the primitive now on the lowered values.
        if all(t.pval.is_known for t in tracers):
            lowered = [full_lower(t) for t in tracers]
            return bind(primitive, *lowered, **params)
        # A primitive-specific partial-evaluation rule takes precedence.
        custom_rule = partial_eval_rules.get(primitive)
        if custom_rule:
            return custom_rule(self, tracers, **params)
        # Otherwise record the application as a recipe with unknown outputs.
        tracers_in = [self.instantiate_const(t) for t in tracers]
        avals_in = [t.aval for t in tracers_in]
        avals_out = abstract_eval_rules[primitive](*avals_in, **params)
        tracers_out = []
        for aval in avals_out:
            tracers_out.append(
                PartialEvalTracer(self, PartialVal.unknown(aval), None))
        # The recipe holds only weak references to its output tracers.
        out_refs = [ref(t) for t in tracers_out]
        eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out, out_refs)
        for out_tracer in tracers_out:
            out_tracer.recipe = eqn
        return tracers_out


partial_eval_rules = {}
现在我们可以使用 PartialEvalTrace
构建 jaxpr 的图形表示,我们需要一个机制将图形表示转换为标准 jaxpr。jaxpr 对应于图表的拓扑排序。
def tracers_to_jaxpr(tracers_in: list[PartialEvalTracer],
                     tracers_out: list[PartialEvalTracer]):
    """Convert the tracer/recipe graph reachable from `tracers_out` into a jaxpr.

    Returns `(jaxpr, constvals)`. The jaxpr's input binders are the constant
    variables followed by one variable per tracer in `tracers_in`.
    """
    tracer_to_var: dict[int, Var] = {}
    for t in tracers_in:
        tracer_to_var[id(t)] = Var(raise_to_shaped(t.aval))
    constvar_to_val: dict[Var, Any] = {}
    constid_to_var: dict[int, Var] = {}
    processed_eqns: set[int] = set()
    eqns: list[JaxprEqn] = []
    arg_ids = {id(t) for t in tracers_in}

    for t in toposort(tracers_out, tracer_parents):
        recipe = t.recipe
        if isinstance(recipe, LambdaBindingRecipe):
            # Argument tracers must be among the declared inputs.
            assert id(t) in arg_ids
        elif isinstance(recipe, ConstRecipe):
            # One variable per distinct constant object (keyed by identity).
            val = recipe.val
            var = constid_to_var.get(id(val))
            if var is None:
                aval = raise_to_shaped(get_aval(val))
                var = constid_to_var[id(val)] = Var(aval)
                constvar_to_val[var] = val
            tracer_to_var[id(t)] = var
        elif isinstance(recipe, JaxprEqnRecipe):
            # A recipe shared by several output tracers yields a single equation.
            if id(recipe) not in processed_eqns:
                eqns.append(recipe_to_eqn(tracer_to_var, recipe))
                processed_eqns.add(id(recipe))
        else:
            raise TypeError(recipe)

    constvars, constvals = unzip2(constvar_to_val.items())
    arg_vars = [tracer_to_var[id(t)] for t in tracers_in]
    in_binders = constvars + arg_vars
    out_vars = [tracer_to_var[id(t)] for t in tracers_out]
    jaxpr = Jaxpr(in_binders, eqns, out_vars)
    typecheck_jaxpr(jaxpr)
    return jaxpr, constvals
def recipe_to_eqn(tracer_to_var: dict[int, Var], recipe: JaxprEqnRecipe
                  ) -> JaxprEqn:
    """Turn an equation recipe into a JaxprEqn.

    Fresh output variables are created and recorded in `tracer_to_var` for
    every output tracer that is still alive.
    """
    inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]
    out_binders = [Var(aval) for aval in recipe.avals_out]
    for t_ref, var in zip(recipe.tracer_refs_out, out_binders):
        tracer = t_ref()
        # A dead weak reference means nothing holds that output tracer any more.
        if tracer is not None:
            tracer_to_var[id(tracer)] = var
    return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)
def tracer_parents(t: PartialEvalTracer) -> list[PartialEvalTracer]:
    """Return the input tracers of `t`'s equation recipe, or [] for any other recipe."""
    if isinstance(t.recipe, JaxprEqnRecipe):
        return t.recipe.tracers_in
    return []
显示代码单元格来源
def toposort(out_nodes: list[Any], parents: Callable[[Any], list[Any]]):
    """Return the nodes reachable from `out_nodes`, ordered so parents precede children.

    Nodes are tracked by identity. The result is validated with
    `check_toposort` before being returned.
    """
    if not out_nodes:
        return []
    out_nodes = remove_duplicates(out_nodes)

    # Pass 1: count, for each reachable node, how many times it is reached.
    child_counts = {}
    pending = list(out_nodes)
    while pending:
        node = pending.pop()
        key = id(node)
        first_visit = key not in child_counts
        child_counts[key] = child_counts.get(key, 0) + 1
        if first_visit:
            pending.extend(parents(node))
    # The initial push of each output node is not an edge from a child.
    for node in out_nodes:
        child_counts[id(node)] -= 1

    # Pass 2: repeatedly emit nodes with no unemitted children.
    ordered = []
    ready = [node for node in out_nodes if not child_counts[id(node)]]
    while ready:
        node = ready.pop()
        ordered.append(node)
        for parent in parents(node):
            if child_counts[id(parent)] != 1:
                child_counts[id(parent)] -= 1
            else:
                ready.append(parent)

    # Nodes were emitted children-first; reverse to get parents first.
    sorted_nodes = ordered[::-1]
    check_toposort(sorted_nodes, parents)
    return sorted_nodes
def remove_duplicates(lst):
    """Return `lst` without repeated objects (compared by identity), keeping first occurrences."""
    seen_ids = set()
    unique = []
    for item in lst:
        if id(item) in seen_ids:
            continue
        seen_ids.add(id(item))
        unique.append(item)
    return unique
def check_toposort(nodes: list[Any], parents: Callable[[Any], list[Any]]):
    """Assert that every node in `nodes` appears after all of its parents."""
    visited = set()
    for node in nodes:
        for parent in parents(node):
            assert id(parent) in visited
        visited.add(id(node))
现在我们可以线性化了!
# Demo: linearize `sin` at 3.0 and print the results next to sin(3.) and cos(3.).
y, sin_lin = linearize(sin, 3.)
print(y, sin(3.))
print(sin_lin(1.), cos(3.))
0.1411200080598672 0.1411200080598672
-0.9899924966004454 -0.9899924966004454
为了处理 linearize
-jit
,我们仍然需要为 xla_call_p
编写部分评估规则。除了跟踪器簿记之外,主要任务是执行 jaxpr 的部分评估,将其“解压缩”为两个 jaxpr。
实际上需要编写两条规则:一条用于跟踪时间部分评估,我们将其称为 xla_call_partial_eval
,另一条用于 jaxpr 的部分评估,我们将其称为 xla_call_peval_eqn
。
def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
    """Trace-time partial evaluation rule for `xla_call`.

    Splits `jaxpr` into a known part, bound immediately on the known inputs,
    and an unknown part recorded as an `xla_call` recipe that takes the
    residuals followed by the unknown inputs.
    """
    del num_consts  # Unused
    in_unknowns = [not t.pval.is_known for t in tracers]
    jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)
    known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)

    # Run the known half now; its trailing outputs are the residuals.
    known_vals = [t.pval.const for t in known_tracers]
    outs1_res = bind(xla_call_p, *known_vals, jaxpr=jaxpr1, num_consts=0)
    num_known_outs = len(jaxpr1.outs) - num_res
    outs1, res = split_list(outs1_res, num_known_outs)
    res_tracers = []
    for r in res:
        res_tracers.append(trace.instantiate_const(full_raise(trace, r)))

    # Record the unknown half as a recipe producing fresh unknown tracers.
    out_avals2 = [v.aval for v in jaxpr2.outs]
    outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
             for v in jaxpr2.outs]
    out_refs = [ref(t) for t in outs2]
    eqn = JaxprEqnRecipe(xla_call_p, res_tracers + unknown_tracers,
                         dict(jaxpr=jaxpr2, num_consts=0),
                         out_avals2, out_refs)
    for out_tracer in outs2:
        out_tracer.recipe = eqn
    return merge_lists(out_unknowns, outs1, outs2)


partial_eval_rules[xla_call_p] = xla_call_partial_eval
def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
                       instantiate: list[bool] | None = None,
                       ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
    """Split `jaxpr` into a known jaxpr and an unknown jaxpr.

    `in_unknowns` flags which inputs are unknown. Returns
    `(jaxpr1, jaxpr2, out_unknowns, num_res)`: `jaxpr1` maps the known inputs
    to the known outputs followed by the residuals, and `jaxpr2` maps the
    residuals followed by the unknown inputs to the unknown outputs. When
    `instantiate` is given, flagged outputs are treated as unknown as well.
    """
    # env maps each variable to True when it is unknown.
    env: dict[Var, bool] = {}
    # Known variables that the unknown equations need to read.
    residuals: set[Var] = set()

    def read(x: Atom) -> bool:
        # Non-variable atoms are always known.
        return type(x) is Var and env[x]

    def write(unk: bool, v: Var) -> None:
        env[v] = unk

    def new_res(x: Atom) -> Atom:
        # Record a known variable as a residual; other atoms pass through untouched.
        if type(x) is Var: residuals.add(x)
        return x

    eqns1, eqns2 = [], []
    # NOTE(review): `map` is used for side effects throughout, which presumably
    # relies on an eager `map` defined elsewhere in this file -- confirm.
    map(write, in_unknowns, jaxpr.in_binders)
    for eqn in jaxpr.eqns:
        unks_in = map(read, eqn.inputs)
        rule = partial_eval_jaxpr_rules.get(eqn.primitive)
        if rule:
            # Primitive-specific rule splits the equation itself.
            eqn1, eqn2, unks_out, res = rule(unks_in, eqn)
            eqns1.append(eqn1); eqns2.append(eqn2); residuals.update(res)
            map(write, unks_out, eqn.out_binders)
        elif any(unks_in):
            # Any unknown input: the equation goes to the unknown side, and its
            # known inputs become residuals.
            inputs = [v if unk else new_res(v) for unk, v in zip(unks_in, eqn.inputs)]
            eqns2.append(JaxprEqn(eqn.primitive, inputs, eqn.params, eqn.out_binders))
            map(partial(write, True), eqn.out_binders)
        else:
            # All inputs known: the equation stays on the known side.
            eqns1.append(eqn)
            map(partial(write, False), eqn.out_binders)
    out_unknowns = map(read, jaxpr.outs)
    if instantiate is not None:
        # Known outputs that must be instantiated are passed across as residuals.
        for v, uk, inst in zip(jaxpr.outs, out_unknowns, instantiate):
            if inst and not uk: new_res(v)
        out_unknowns = map(op.or_, out_unknowns, instantiate)

    residuals, num_res = list(residuals), len(residuals)
    assert all(type(v) is Var for v in residuals), residuals

    ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
    outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)

    jaxpr1 = Jaxpr(ins1, eqns1, outs1 + residuals)
    jaxpr2 = Jaxpr(residuals + ins2, eqns2, outs2)
    typecheck_partial_eval_jaxpr(jaxpr, in_unknowns, out_unknowns, jaxpr1, jaxpr2)

    return jaxpr1, jaxpr2, out_unknowns, num_res
def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):
    """Check that `jaxpr1` and `jaxpr2` have types consistent with a split of `jaxpr`.

    Raises:
        TypeError: if any of the type comparisons below fails.
    """
    jaxprty = typecheck_jaxpr(jaxpr)    # (a1,  a2) -> (b1, b2 )
    jaxpr1ty = typecheck_jaxpr(jaxpr1)  #  a1       -> (b1, res)
    jaxpr2ty = typecheck_jaxpr(jaxpr2)  # (res, a2) -> b2

    a1, a2 = partition_list(unks_in, jaxprty.in_types)
    b1, b2 = partition_list(unks_out, jaxprty.out_types)
    b1_, res = split_list(jaxpr1ty.out_types, len(b1))
    res_, a2_ = split_list(jaxpr2ty.in_types, len(res))
    b2_ = jaxpr2ty.out_types

    # Checked in this order; the first mismatch raises.
    comparisons = (
        (jaxpr1ty.in_types, a1),
        (jaxpr2ty.out_types, b2),
        (b1, b1_),
        (res, res_),
        (a2, a2_),
        (b2, b2_),
    )
    for left, right in comparisons:
        if left != right:
            raise TypeError
# Per-primitive rules consulted by `partial_eval_jaxpr` to split a single equation.
partial_eval_jaxpr_rules = {}
def xla_call_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
                       ) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Var]]:
    """Split an `xla_call` equation into a known and an unknown `xla_call` equation.

    Returns `(eqn1, eqn2, unks_out, residuals)`. `eqn1` outputs the known
    results followed by fresh residual variables, which `eqn2` takes as its
    leading inputs.
    """
    inner_jaxpr = eqn.params['jaxpr']
    jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(inner_jaxpr, unks_in)
    ins1, ins2 = partition_list(unks_in, eqn.inputs)
    out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)
    # Fresh outer variables typed like jaxpr2's leading (residual) inputs.
    residual_binders = jaxpr2.in_binders[:num_res]
    residuals = [Var(v.aval) for v in residual_binders]
    params1 = dict(jaxpr=jaxpr1, num_consts=0)
    params2 = dict(jaxpr=jaxpr2, num_consts=0)
    eqn1 = JaxprEqn(xla_call_p, ins1, params1, out_binders1 + residuals)
    eqn2 = JaxprEqn(xla_call_p, residuals + ins2, params2, out_binders2)
    return eqn1, eqn2, unks_out, residuals


partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn
有了它,我们可以根据需要组合 linearize
和 jit
。
@jit
def f(x):
    doubled_sin = sin(x) * 2.
    return - doubled_sin + x


# linearize of a jitted function.
y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
2.7177599838802657 2.979984993200891
@jit
def f(x):
    doubled_sin = sin(x) * 2.
    # `g` is defined below; it is looked up when `f` is first called.
    return g(x, doubled_sin)


@jit
def g(x, y):
    cos_x = cos(x)
    return cos_x + y


# linearize through a jitted function that calls another jitted function.
y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
-0.7077524804807109 -2.121105001260758
vjp
和 grad
#
该 vjp
变换的工作方式非常类似于线性化。它的类型签名类似于
linearize : (a -> b) -> a -> (b, T a -o T b)
vjp : (a -> b) -> a -> (b, T b -o T a)
唯一的区别是我们在返回线性部分的计算之前将其转置,以便它从类型 T a -o T b
变为类型 T b -o T a
。也就是说,我们将实现 vjp
,本质上是
def vjp(f, x):
y, f_lin = linearize(f, x)
f_vjp = lambda y_bar: transpose(f_lin)(y_bar)
return y, f_vjp
由于我们拥有线性计算作为 jaxpr,而不仅仅是一个 Python 可调用对象,因此我们可以将转置变换实现为 jaxpr 解释器。
def vjp_flat(f, *primals_in):
    """VJP of `f` at `primals_in` for flat inputs and outputs.

    Linearizes `f` by partially evaluating its JVP, then returns the primal
    outputs together with a function that evaluates the linear jaxpr
    transposed on output cotangents.
    """
    known_primals = [PartialVal.known(x) for x in primals_in]
    unknown_tangents = [PartialVal.unknown(vspace(get_aval(x)))
                        for x in primals_in]
    pvals_in = known_primals + unknown_tangents
    primal_pvals_in, tangent_pvals_in = split_half(pvals_in)

    def f_jvp(*primals_tangents_in):
        primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
        return [*primals_out, *tangents_out]

    jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)  # linearize
    primal_pvals, _ = split_half(pvals_out)
    assert all(pval.is_known for pval in primal_pvals)
    primals_out = [pval.const for pval in primal_pvals]
    # Tangent inputs are the positions to transpose, marked with UndefPrimal.
    undefined = [UndefPrimal(p.aval) for p in tangent_pvals_in]
    transpose_inputs = consts + undefined

    def f_vjp(*cts):
        return eval_jaxpr_transposed(jaxpr, transpose_inputs, cts)

    return primals_out, f_vjp
def vjp(f, *primals_in):
    """Pytree-aware wrapper around `vjp_flat`.

    Returns the primal outputs and a function mapping output cotangents to
    input cotangents, structured like `primals_in`.
    """
    primals_in_flat, in_tree = tree_flatten(primals_in)
    flat_f, out_tree = flatten_fun(f, in_tree)
    primals_out_flat, f_vjp_flat = vjp_flat(flat_f, *primals_in_flat)
    primals_out = tree_unflatten(out_tree(), primals_out_flat)

    def f_vjp(*cotangents_out):
        # The cotangents' tree structure is discarded, not checked.
        cotangents_out_flat, _ = tree_flatten(cotangents_out)
        cotangents_in_flat = f_vjp_flat(*cotangents_out_flat)
        return tree_unflatten(in_tree, cotangents_in_flat)

    return primals_out, f_vjp
class UndefPrimal(NamedTuple):
    # Placeholder for an input to be transposed; carries only its abstract value.
    aval: ShapedArray


# Register UndefPrimal as a pytree node with no children; the aval is kept as
# auxiliary data and used to rebuild the placeholder.
register_pytree_node(UndefPrimal,
lambda u: (u.aval, ()),
lambda aval, _: UndefPrimal(aval))
我们使用 UndefPrimal
实例来指示我们想要对哪些参数进行转置。之所以需要它们,是因为在一般情况下,如果把闭包捕获的值显式地写出来,我们希望将类型为 a -> b -o c
的函数转置为类型为 a -> c -o b
的函数。更一般地说,函数对其呈线性的那些输入可能散布在参数列表的各处。因此,我们使用 UndefPrimal
来指示线性位置。我们将 UndefPrimal
注册为 pytree 节点,因为 pytree 机制提供了一种方便的方法来从参数列表中修剪掉这些占位符。
接下来,我们可以编写 eval_jaxpr_transposed
,并为所有至少对一个参数可以是线性的原语编写转置规则。
# NB: the analogous function in JAX is called 'backward_pass'
def eval_jaxpr_transposed(jaxpr: Jaxpr, args: list[Any], cotangents: list[Any]
                          ) -> list[Any]:
    """Evaluate `jaxpr` transposed, walking its equations in reverse.

    `args` lines up with `jaxpr.in_binders`; entries that are `UndefPrimal`
    mark the inputs being transposed. Returns one cotangent per such input.
    """
    primal_env: dict[Var, Any] = {}
    ct_env: dict[Var, Any] = {}

    def read_primal(x: Atom) -> Any:
        if type(x) is not Var:
            return x.val
        # Variables with no stored primal value read back as UndefPrimal.
        return primal_env.get(x, UndefPrimal(x.aval))

    def write_primal(v: Var, val: Any) -> None:
        if type(val) is not UndefPrimal:
            primal_env[v] = val

    def read_cotangent(v: Var) -> Any:
        # A variable that never received a cotangent contributes zeros.
        zeros = np.zeros(v.aval.shape, v.aval.dtype)
        return ct_env.pop(v, zeros)

    def write_cotangent(x: Atom, val: Any):
        if type(x) is Var and val is not None:
            # Cotangents written to the same variable accumulate by addition.
            if x in ct_env:
                ct_env[x] = add(ct_env[x], val)
            else:
                ct_env[x] = val

    # NOTE(review): `map` is called for its side effects here, which presumably
    # relies on an eager `map` defined elsewhere in this file -- confirm.
    map(write_primal, jaxpr.in_binders, args)
    map(write_cotangent, jaxpr.outs, cotangents)
    for eqn in jaxpr.eqns[::-1]:
        primals_in = map(read_primal, eqn.inputs)
        cts_in = map(read_cotangent, eqn.out_binders)
        transpose_rule = transpose_rules[eqn.primitive]
        cts_out = transpose_rule(cts_in, *primals_in, **eqn.params)
        map(write_cotangent, eqn.inputs, cts_out)

    result = []
    for binder, arg in zip(jaxpr.in_binders, args):
        if type(arg) is UndefPrimal:
            result.append(read_cotangent(binder))
    return result


transpose_rules = {}
def mul_transpose_rule(cts, x, y):
    """Transpose rule for `mul`: exactly one operand must be the undefined (linear) one."""
    (z_bar,) = cts
    x_is_linear = type(x) is UndefPrimal
    y_is_linear = type(y) is UndefPrimal
    assert x_is_linear ^ y_is_linear
    if x_is_linear:
        return [mul(z_bar, y), None]
    return [None, mul(x, z_bar)]


transpose_rules[mul_p] = mul_transpose_rule
def neg_transpose_rule(cts, x):
  """Transpose rule for `neg_p`: negate the single output cotangent.

  `x` must be an `UndefPrimal`.
  """
  (ct_out,) = cts
  assert type(x) is UndefPrimal
  negated = neg(ct_out)
  return [negated]


transpose_rules[neg_p] = neg_transpose_rule
def add_transpose_rule(cts, x, y):
  """Transpose rule for `add_p`: both operands receive the output cotangent."""
  (ct_out,) = cts
  return [ct_out] * 2


transpose_rules[add_p] = add_transpose_rule
def reduce_sum_transpose_rule(cts, x, *, axis):
  """Transpose rule for `reduce_sum_p`.

  Broadcasts the single output cotangent to the shape recorded in `x.aval`,
  passing `axis` through to `broadcast`.
  """
  (ct_out,) = cts
  operand_shape = x.aval.shape
  return [broadcast(ct_out, operand_shape, axis)]


transpose_rules[reduce_sum_p] = reduce_sum_transpose_rule
def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
  """Transpose rule for `xla_call_p`.

  Transposes the inner jaxpr with respect to the `UndefPrimal` positions of
  `invals`, binds a new `xla_call_p` on the transposed jaxpr, and returns one
  entry per input: a cotangent for `UndefPrimal` positions, None elsewhere.
  """
  del num_consts  # Unused
  is_undef = [type(val) is UndefPrimal for val in invals]
  transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(is_undef))
  # The first partition holds the inputs that are not UndefPrimal.
  residuals, _ = partition_list(is_undef, invals)
  ct_stream = iter(bind(xla_call_p, *new_consts, *residuals, *cts,
                        jaxpr=transposed_jaxpr, num_consts=len(new_consts)))
  # Consume the call's outputs in order, one per UndefPrimal position.
  return [next(ct_stream) if undef else None for undef in is_undef]


transpose_rules[xla_call_p] = xla_call_transpose_rule
@lru_cache
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
                    ) -> tuple[Jaxpr, list[Any]]:
  """Build a jaxpr that computes the transpose of `jaxpr`.

  Args:
    jaxpr: the jaxpr to transpose. Must be hashable, since results are
      memoized with `lru_cache`.
    undef_primals: one flag per input of `jaxpr`; True marks an input that is
      replaced by an `UndefPrimal` placeholder while tracing.

  Returns:
    The traced transposed jaxpr and the constants returned by `make_jaxpr`.
  """
  in_avals, out_avals = typecheck_jaxpr(jaxpr)
  transposed_fn = partial(eval_jaxpr_transposed, jaxpr)
  placeholder_args = tuple(
      UndefPrimal(aval) if is_undef else aval
      for aval, is_undef in zip(in_avals, undef_primals))
  result_jaxpr, result_consts, _ = make_jaxpr(
      transposed_fn, placeholder_args, tuple(out_avals))
  # Type-check the traced result before handing it back.
  typecheck_jaxpr(result_jaxpr)
  return result_jaxpr, result_consts
现在我们可以线性化和转置,我们终于可以编写 grad
了。
def grad(f):
  """Return a function computing the gradient of `f` w.r.t. its first argument.

  The returned function calls `vjp(f, x, *xs)`, requires the output to have
  shape `()`, applies the VJP to a ones cotangent of the output's shape and
  result type, and returns only the first cotangent.

  Raises (from the returned function):
    TypeError: if the output of `f` does not have shape `()`.
  """
  def gradfun(x, *xs):
    y, f_vjp = vjp(f, x, *xs)
    out_shape = np.shape(y)
    if out_shape != ():
      raise TypeError(
          f"grad requires a scalar-output function, got output shape {out_shape}")
    # Cotangents for the remaining arguments are discarded.
    x_bar, *_ = f_vjp(np.ones(np.shape(y), np.result_type(y)))
    return x_bar
  return gradfun
y, f_vjp = vjp(sin, 3.)
print(f_vjp(1.), cos(3.))
(np.float64(-0.9899924966004454),) -0.9899924966004454
def f(x):
  """Example function: returns -(sin(x) * 2.) + x."""
  doubled_sin = sin(x) * 2.
  return - doubled_sin + x
print(grad(f)(3.))
2.979984993200891
@jit
def f(x):
  """Jitted example: doubles `x` and passes the result to `g`."""
  return g(x * 2.)
@jit
def g(x):
  """Jitted example: returns cos(x) * 2."""
  cos_x = cos(x)
  return cos_x * 2.
print(grad(f)(3.))
1.1176619927957034
这是一个组合性压力测试。
# from core_test.py fun_with_nested_calls_2
def foo(x):
  # Stress-test function combining nested jit calls, closures over values
  # from enclosing scopes, and a jvp taken inside a jitted function.
  @jit
  def bar(y):
    def baz(w):
      # The lambda's own parameter `x` is unused; it returns the closed-over
      # `y` from `bar`.
      q = jit(lambda x: y)(x)
      # Zero-argument jitted function returning the closed-over `y`.
      q = q + jit(lambda: y)()
      # Here the lambda parameter `y` shadows the enclosing `y`; `w` is
      # closed over from `baz`.
      q = q + jit(lambda y: w + y)(y)
      # The lambda parameter `w` (bound to 1.0) shadows baz's `w` and is
      # unused; `x` and `y` come from the enclosing scopes, and jit(sin) is
      # called inside the jitted lambda.
      q = jit(lambda w: jit(sin)(x) * y)(1.0) + q
      return q
    # jvp of baz at the point x + 1.0 with tangent y.
    p, t = jvp(baz, (x + 1.0,), (y,))
    return t + (x * p)
  return bar(x)
def assert_allclose(*vals):
  """Check every pair of neighbouring values with np.testing.assert_allclose.

  With zero or one value nothing is compared.
  """
  if not vals:
    return
  previous = vals[0]
  for current in vals[1:]:
    np.testing.assert_allclose(previous, current)
    previous = current
# These checks exercise `f` (whichever definition of that name is in scope at
# this point); `foo` is not referenced by any line of this script.

# Primal values: direct call, jit, and the primal output of jvp must agree.
ans1 = f(3.)
ans2 = jit(f)(3.)
ans3, _ = jvp(f, (3.,), (5.,))
ans4, _ = jvp(jit(f), (3.,), (5.,))
assert_allclose(ans1, ans2, ans3, ans4)
# First derivatives: grad under various jit placements must match the jvp
# tangent computed with unit input tangent.
deriv1 = grad(f)(3.)
deriv2 = grad(jit(f))(3.)
deriv3 = jit(grad(jit(f)))(3.)
_, deriv4 = jvp(f, (3.,), (1.,))
_, deriv5 = jvp(jit(f), (3.,), (1.,))
assert_allclose(deriv1, deriv2, deriv3, deriv4, deriv5)
# Second derivatives: grad-of-grad and jvp-of-grad compositions must agree.
hess1 = grad(grad(f))(3.)
hess2 = grad(grad(jit(f)))(3.)
hess3 = grad(jit(grad(f)))(3.)
hess4 = jit(grad(grad(f)))(3.)
_, hess5 = jvp(grad(f), (3.,), (1.,))
_, hess6 = jvp(jit(grad(f)), (3.,), (1.,))
# NOTE(review): hess7 is the same expression as hess6, so it does not cover a
# new composition — confirm whether a different one (e.g. grad(jit(f)) under
# jvp) was intended.
_, hess7 = jvp(jit(grad(f)), (3.,), (1.,))
assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)
第 5 部分:控制流原语 cond
#
接下来,我们将为分阶段控制流添加高阶原语。这些类似于第 3 部分中的 jit
(另一个高阶原语),但不同之处在于它们由多个可调用对象而不是单个可调用对象参数化。
添加 cond
#
我们引入了一个 cond
原始类型来表示在 jaxpr 中对一个函数或另一个函数的条件应用。我们将 cond
的类型写为 Bool -> (a -> b) -> (a -> b) -> a -> b
。换句话说,cond
接收一个表示谓词的布尔值和两个类型相同的函数。根据谓词的值,它将一个函数或另一个函数应用于其最终参数。
在 Python 中,我们将它表示为一个函数,该函数本身接收两个函数作为参数。与 jit
一样,第一步是对其可调用参数调用 make_jaxpr
,将其转换为 jaxpr。
def cond(pred, true_fn, false_fn, *operands):
  """Apply `true_fn` or `false_fn` to `operands`, selected by `pred`.

  Both callables are traced to jaxprs with `make_jaxpr` on the operands'
  shaped avals, their constant inputs are merged so both jaxprs take the same
  inputs, and a single `cond_p` is bound.

  Args:
    pred: the predicate forwarded to `bind_cond`.
    true_fn: callable traced for the true branch.
    false_fn: callable traced for the false branch.
    *operands: arguments passed to whichever branch is selected.

  Returns:
    The outputs of the bound primitive, unflattened with the branches' shared
    output tree.

  Raises:
    TypeError: if the branches' output trees differ, or if the joined jaxprs
      do not type-check to the same type.
  """
  avals_in = [raise_to_shaped(get_aval(x)) for x in operands]
  true_jaxpr, true_consts, out_tree = make_jaxpr(true_fn, *avals_in)
  false_jaxpr, false_consts, out_tree_ = make_jaxpr(false_fn, *avals_in)
  if out_tree != out_tree_:
    raise TypeError(
        "cond: true_fn and false_fn must return the same pytree structure")
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  if typecheck_jaxpr(true_jaxpr) != typecheck_jaxpr(false_jaxpr):
    raise TypeError(
        "cond: true_fn and false_fn must have matching input and output types")
  # Constants of both branches are passed to the primitive, ahead of the
  # operands, matching the joined input binders.
  outs = bind_cond(pred, *true_consts, *false_consts, *operands,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  return tree_unflatten(out_tree, outs)
cond_p = Primitive('cond')


def _join_jaxpr_consts(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
                       ) -> tuple[Jaxpr, Jaxpr]:
  """Give two jaxprs a common list of constant input binders.

  The first `n1` input binders of `jaxpr1` and the first `n2` of `jaxpr2` are
  treated as constants. Each returned jaxpr takes the concatenation of both
  constant lists followed by its own remaining binders; equations and outputs
  are unchanged. The non-constant input types of the two jaxprs must match.
  """
  type1 = typecheck_jaxpr(jaxpr1)
  type2 = typecheck_jaxpr(jaxpr2)
  assert type1.in_types[n1:] == type2.in_types[n2:]
  consts1, rest1 = split_list(jaxpr1.in_binders, n1)
  consts2, rest2 = split_list(jaxpr2.in_binders, n2)
  shared_consts = consts1 + consts2
  joined1 = Jaxpr(shared_consts + rest1, jaxpr1.eqns, jaxpr1.outs)
  joined2 = Jaxpr(shared_consts + rest2, jaxpr2.eqns, jaxpr2.outs)
  return joined1, joined2
def bind_cond(pred, *args, true_jaxpr, false_jaxpr):
  """Bind `cond_p`, asserting that `args` matches both branches' input count."""
  n_args = len(args)
  assert n_args == len(true_jaxpr.in_binders) == len(false_jaxpr.in_binders)
  branches = dict(true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  return bind(cond_p, pred, *args, **branches)
我们要求 true_jaxpr
和 false_jaxpr
具有相同的类型,但由于它们可能闭包捕获不同的常量(并且由于 jaxpr 只能表示闭合项,即不能含有自由变量,自由变量会经过闭包转换变为显式输入),因此我们需要使用辅助函数 _join_jaxpr_consts
来使两个 jaxpr 的输入绑定器列表保持一致。(为了更经济,我们可以尝试识别形状相同的常量对,但我们只是将常量列表连接起来。)
接下来,我们可以转向为 cond
添加解释器规则。它的评估规则很简单。
def cond_impl(pred, *operands, true_jaxpr, false_jaxpr):
  """Evaluation rule for `cond_p`: evaluate the branch selected by `pred`."""
  chosen_jaxpr = true_jaxpr if pred else false_jaxpr
  return eval_jaxpr(chosen_jaxpr, operands)


impl_rules[cond_p] = cond_impl
out = cond(True, lambda: 3, lambda: 4)
print(out)
3
对于它的 JVP 和 vmap 规则,我们只需要调用之前为 jit
创建的同一组 jvp_jaxpr
和 vmap_jaxpr
实用程序,然后再次进行 _join_jaxpr_consts
。
def cond_jvp_rule(primals, tangents, *, true_jaxpr, false_jaxpr):
  """JVP rule for `cond_p`.

  Applies `jvp_jaxpr` to both branches, joins their constants, and binds one
  `cond_p` on the primals followed by the tangents. The predicate's tangent
  is dropped. Returns the two halves of the outputs as (primals, tangents).
  """
  pred, *branch_primals = primals
  _, *branch_tangents = tangents
  true_jvp, true_consts = jvp_jaxpr(true_jaxpr)
  false_jvp, false_consts = jvp_jaxpr(false_jaxpr)
  true_jvp, false_jvp = _join_jaxpr_consts(
      true_jvp, false_jvp, len(true_consts), len(false_consts))
  assert typecheck_jaxpr(true_jvp) == typecheck_jaxpr(false_jvp)
  outs = bind_cond(pred, *true_consts, *false_consts,
                   *branch_primals, *branch_tangents,
                   true_jaxpr=true_jvp, false_jaxpr=false_jvp)
  out_primals, out_tangents = split_half(outs)
  return out_primals, out_tangents


jvp_rules[cond_p] = cond_jvp_rule
out, out_tan = jvp(lambda x: cond(True, lambda: x * x, lambda: 0.), (1.,), (1.,))
print(out_tan)
2.0
def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr):
  """Batching rule for `cond_p`.

  Only an unbatched predicate is handled. Both branches are batched with
  `vmap_jaxpr`, their constants are joined, and every output is reported as
  batched along axis 0.

  Raises:
    NotImplementedError: if the predicate itself is batched.
  """
  pred, *operand_vals = vals_in
  pred_dim, *operand_dims = dims_in
  if pred_dim is not not_mapped:
    raise NotImplementedError  # TODO
  batch_dims = tuple(operand_dims)
  true_batched, true_consts = vmap_jaxpr(true_jaxpr, axis_size, batch_dims)
  false_batched, false_consts = vmap_jaxpr(false_jaxpr, axis_size, batch_dims)
  true_batched, false_batched = _join_jaxpr_consts(
      true_batched, false_batched, len(true_consts), len(false_consts))
  assert typecheck_jaxpr(true_batched) == typecheck_jaxpr(false_batched)
  outs = bind_cond(pred, *true_consts, *false_consts, *operand_vals,
                   true_jaxpr=true_batched, false_jaxpr=false_batched)
  return outs, [0 for _ in outs]


vmap_rules[cond_p] = cond_vmap_rule
xs = np.array([1., 2., 3])
out = vmap(lambda x: cond(True, lambda: x + 1., lambda: 0.), (0,))(xs)
print(out)
[2. 3. 4.]
请注意,我们目前不支持谓词值本身被批处理的情况。在 JAX 主线中,我们通过将条件转换为 select 原语来处理这种情况。只要 true_fun
和 false_fun
不涉及任何带副作用的原语,这种转换在语义上就是正确的。
另一个未在此处表示,但在 JAX 主线中存在的事情是,将变换应用于两个类型相同的 jaxpr 可能会导致类型不同的 jaxpr。例如,将 JAX 主线版本的 vmap_jaxpr
应用于恒等函数 jaxpr
{ lambda a:float32[] .
let
in ( a ) }
将导致一个具有批处理输出的 jaxpr,类型为 [float32[10]] -> [float32[10]]
(如果批处理大小为 10),而将其应用于零函数 jaxpr
{ lambda a:float32[] .
let
in ( 0. ) }
将导致一个具有非批处理输出的 jaxpr,类型为 [float32[10]] -> [float32[]]
。这是一项优化,旨在避免不必要地对值进行批处理。但这意味着在 cond
中,我们需要一个额外的步骤来合并两个经过变换的 jaxpr 以具有一致的输出类型。我们在这里不需要此步骤,因为我们选择 vmap_jaxpr
始终在主轴上对所有输出进行批处理。
接下来,我们可以转向抽象评估和 XLA 降低规则。
def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr):
  """Abstract evaluation rule for `cond_p`.

  Args:
    pred_type: abstract value of the predicate; must equal a shape-`()`
      `ShapedArray` of dtype bool.
    *in_types: abstract values of the remaining operands.
    true_jaxpr: jaxpr of the true branch.
    false_jaxpr: jaxpr of the false branch.

  Returns:
    The output types of the (identically typed) branches.

  Raises:
    TypeError: if the predicate type is wrong, the branch types differ, or
      an operand type differs from the corresponding branch input type.
  """
  if pred_type != ShapedArray((), np.dtype('bool')):
    raise TypeError(f"cond predicate must be a scalar bool, got {pred_type}")
  jaxpr_type = typecheck_jaxpr(true_jaxpr)
  if jaxpr_type != typecheck_jaxpr(false_jaxpr):
    raise TypeError("cond branches must have identical jaxpr types")
  if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
    raise TypeError("cond operand types do not match the branch input types")
  return jaxpr_type.out_types


abstract_eval_rules[cond_p] = cond_abstract_eval
def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):
  """XLA translation rule for `cond_p`.

  Packs the non-predicate inputs into one tuple operand, builds one
  sub-computation per branch taking that tuple as its only parameter, and
  emits an `xops.Conditional` indexed by the predicate converted to int32.
  """
  del in_avals  # Unused
  pred, *branch_vals = in_vals
  flat_vals, in_tree = tree_flatten(branch_vals)
  operand = xops.Tuple(c, flat_vals)
  operand_shape = c.get_shape(operand)

  def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:
    # Each branch gets its own builder whose single parameter is the tuple.
    sub_builder = xc.XlaBuilder(name)
    param = xops.Parameter(sub_builder, 0, operand_shape)
    branch_args = tree_unflatten(in_tree, destructure_tuple(sub_builder, param))
    branch_outs = jaxpr_subcomp(sub_builder, jaxpr, branch_args)
    return sub_builder.build(xops.Tuple(sub_builder, branch_outs))

  true_comp = make_comp('true_fn', true_jaxpr)
  false_comp = make_comp('false_fn', false_jaxpr)

  int_etype = xc.dtype_to_etype(np.dtype('int32'))
  branch_index = xops.ConvertElementType(pred, int_etype)
  # The false branch is listed first and both branches get the same operand.
  # NOTE(review): presumably the integer selects a branch by position, so
  # False -> 0 -> false_comp and True -> 1 -> true_comp — confirm.
  out = xops.Conditional(branch_index, [false_comp, true_comp], [operand] * 2)
  return destructure_tuple(c, out)


xla_translations[cond_p] = cond_translation
out = jit(lambda: cond(False, lambda: 1, lambda: 2))()
print(out)
2
最后,为了支持反向模式自动微分,我们需要部分评估和转置规则。对于部分评估,我们需要引入另一个 jaxpr 整理实用程序 _join_jaxpr_res
,以处理将部分评估应用于 true_fun
和 false_fun
通常会导致不同残差的事实。我们使用 _join_jaxpr_res
来使经过变换的 jaxpr 的输出类型保持一致(而 _join_jaxpr_consts
处理输入类型)。
def cond_partial_eval(trace, tracers, *, true_jaxpr, false_jaxpr):
  """Partial evaluation rule for `cond_p`.

  The predicate must be known. The known part of the cond is bound
  immediately on the known inputs, producing known outputs plus residuals.
  The unknown part is recorded as a `JaxprEqnRecipe` on fresh unknown output
  tracers. Known and unknown outputs are merged back into the original order.
  """
  pred_tracer, *operand_tracers = tracers
  assert pred_tracer.pval.is_known
  pred = pred_tracer.pval.const
  in_uks = [not t.pval.is_known for t in operand_tracers]

  (t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2,
   out_uks, num_res) = _cond_partial_eval(true_jaxpr, false_jaxpr, in_uks)

  known_tracers, unknown_tracers = partition_list(in_uks, operand_tracers)
  known_vals = [t.pval.const for t in known_tracers]

  # Known half: the trailing `num_res` outputs are the residuals.
  known_outs_and_res = bind_cond(pred, *known_vals,
                                 true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1)
  outs1, res = split_list(known_outs_and_res,
                          len(known_outs_and_res) - num_res)

  # Unknown half: predicate and residuals become inputs of the staged cond.
  pred_tracer_ = trace.instantiate_const(full_raise(trace, pred_tracer))
  res_tracers = [trace.instantiate_const(full_raise(trace, r)) for r in res]
  outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
           for v in t_jaxpr2.outs]
  recipe = JaxprEqnRecipe(
      cond_p, [pred_tracer_, *res_tracers, *unknown_tracers],
      dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
      [v.aval for v in t_jaxpr2.outs], map(ref, outs2))
  for out_tracer in outs2:
    out_tracer.recipe = recipe
  return merge_lists(out_uks, outs1, outs2)


partial_eval_rules[cond_p] = cond_partial_eval
def _cond_partial_eval(true_jaxpr: Jaxpr, false_jaxpr: Jaxpr, in_uks: list[bool]
                       ) -> tuple[Jaxpr, Jaxpr, Jaxpr, Jaxpr, list[bool], int]:
  """Partially evaluate both branches of a cond consistently.

  Returns the known-part jaxprs (true, false), the unknown-part jaxprs
  (true, false), the per-output unknown flags shared by both branches, and
  the total number of residuals (true-branch count plus false-branch count).
  """
  # First pass: find which outputs each branch leaves unknown, then combine
  # the two flag lists element-wise with `op.or_`.
  _, _, true_out_uks, _ = partial_eval_jaxpr(true_jaxpr , in_uks)
  _, _, false_out_uks, _ = partial_eval_jaxpr(false_jaxpr, in_uks)
  out_uks = map(op.or_, true_out_uks, false_out_uks)

  # Second pass: redo partial evaluation with the combined flags passed in.
  t_jaxpr1, t_jaxpr2, _, t_nres = partial_eval_jaxpr(true_jaxpr , in_uks, out_uks)
  f_jaxpr1, f_jaxpr2, _, f_nres = partial_eval_jaxpr(false_jaxpr, in_uks, out_uks)

  # Align residual outputs of the known halves and residual inputs of the
  # unknown halves so that both branches have the same type.
  t_jaxpr1, f_jaxpr1 = _join_jaxpr_res(t_jaxpr1, f_jaxpr1, t_nres, f_nres)
  t_jaxpr2, f_jaxpr2 = _join_jaxpr_consts(t_jaxpr2, f_jaxpr2, t_nres, f_nres)
  assert typecheck_jaxpr(t_jaxpr1) == typecheck_jaxpr(f_jaxpr1)
  assert typecheck_jaxpr(t_jaxpr2) == typecheck_jaxpr(f_jaxpr2)

  total_res = t_nres + f_nres
  return t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2, out_uks, total_res
def _join_jaxpr_res(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
                    ) -> tuple[Jaxpr, Jaxpr]:
  """Give two jaxprs a common list of residual outputs.

  The last `n1` outputs of `jaxpr1` and the last `n2` of `jaxpr2` are treated
  as residuals. Each returned jaxpr emits its ordinary outputs, then a slot
  for every residual of `jaxpr1`, then a slot for every residual of `jaxpr2`;
  slots belonging to the other jaxpr are filled with zero literals of the
  matching shape and dtype. The ordinary output types must agree.
  """
  type1, type2 = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
  n_plain1 = len(jaxpr1.outs) - n1
  n_plain2 = len(jaxpr2.outs) - n2
  plain_types1, _ = split_list(type1.out_types, n_plain1)
  plain_types2, _ = split_list(type2.out_types, n_plain2)
  assert plain_types1 == plain_types2
  outs1, res1 = split_list(jaxpr1.outs, n_plain1)
  outs2, res2 = split_list(jaxpr2.outs, n_plain2)

  def zero_lits(residuals):
    # One zero-valued literal per residual, matching its aval.
    return [Lit(np.zeros(r.aval.shape, r.aval.dtype)) for r in residuals]

  padded1 = Jaxpr(jaxpr1.in_binders, jaxpr1.eqns, outs1 + res1 + zero_lits(res2))
  padded2 = Jaxpr(jaxpr2.in_binders, jaxpr2.eqns, outs2 + zero_lits(res1) + res2)
  return padded1, padded2
_, f_lin = linearize(lambda x: cond(True, lambda: x, lambda: 0.), 1.)
out = f_lin(3.14)
print(out)
3.14
def cond_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
                   ) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Atom]]:
  """Split a `cond_p` equation into a known and an unknown equation.

  The predicate (first input) must be known. Returns the known equation, the
  unknown equation, the per-output unknown flags, and the residual atoms;
  the predicate is included among the residuals only when it is a `Var`.
  """
  pred_unk, *operand_unks = unks_in
  assert not pred_unk
  params = eqn.params
  true_jaxpr, false_jaxpr = params['true_jaxpr'], params['false_jaxpr']
  (t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2,
   unks_out, num_res) = _cond_partial_eval(true_jaxpr, false_jaxpr, operand_unks)

  pred_atom = eqn.inputs[0]
  known_ins, unknown_ins = partition_list(operand_unks, eqn.inputs[1:])
  known_outs, unknown_outs = partition_list(unks_out, eqn.out_binders)
  # The leading input binders of the unknown-side jaxpr name the residuals.
  residuals, _ = split_list(t_jaxpr2.in_binders, num_res)

  known_eqn = JaxprEqn(cond_p, [pred_atom, *known_ins],
                       dict(true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1),
                       known_outs + residuals)
  unknown_eqn = JaxprEqn(cond_p, [pred_atom, *residuals, *unknown_ins],
                         dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
                         unknown_outs)
  if type(pred_atom) is Var:
    res = [pred_atom, *residuals]
  else:
    res = residuals
  return known_eqn, unknown_eqn, unks_out, res


partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn
_, f_lin = linearize(jit(lambda x: cond(True, lambda: x, lambda: 0.)), 1.)
out = f_lin(3.14)
print(out)
3.14
转置是 transpose_jaxpr
的一个相当简单的应用。
def cond_transpose_rule(cts, pred, *invals, true_jaxpr, false_jaxpr):
  """Transpose rule for `cond_p`.

  Transposes both branches with respect to the `UndefPrimal` positions of
  `invals`, joins their constants, and binds one `cond_p` on the known inputs
  followed by the cotangents. The result has a leading None for the
  predicate, then a cotangent for each `UndefPrimal` input and None for the
  others.
  """
  is_undef = tuple(type(val) is UndefPrimal for val in invals)
  true_t, true_consts = transpose_jaxpr(true_jaxpr, is_undef)
  false_t, false_consts = transpose_jaxpr(false_jaxpr, is_undef)
  true_t, false_t = _join_jaxpr_consts(
      true_t, false_t, len(true_consts), len(false_consts))
  known_inputs = [val for val in invals if type(val) is not UndefPrimal]
  ct_stream = iter(bind_cond(pred, *true_consts, *false_consts,
                             *known_inputs, *cts,
                             true_jaxpr=true_t, false_jaxpr=false_t))
  # Consume the outputs in order, one per UndefPrimal position.
  input_cts = [next(ct_stream) if undef else None for undef in is_undef]
  return [None] + input_cts


transpose_rules[cond_p] = cond_transpose_rule
out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
print(out)
2.0
显示代码单元格来源
def pprint_cond(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
  """Pretty-print a `cond_p` equation.

  The first line shows the output binders, the primitive name, any params
  whose key does not end in 'jaxpr', and the inputs. The true and false
  jaxprs follow, each indented by two.
  """
  true_jaxpr = eqn.params['true_jaxpr']
  false_jaxpr = eqn.params['false_jaxpr']
  shown_params = {key: val for key, val in eqn.params.items()
                  if not key.endswith('jaxpr')}

  def atom_str(atom):
    # Variables print under their assigned name, other atoms by value.
    return names[atom] if isinstance(atom, Var) else str(atom.val)

  # Output binders are rendered before inputs, as in the printed line.
  lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
  inputs_doc = pp(' '.join(atom_str(x) for x in eqn.inputs))
  rhs = pp(eqn.primitive.name) >> pp_params(shown_params) >> inputs_doc
  header = lhs >> pp(' = ') >> rhs
  return vcat([header,
               pp_jaxpr(true_jaxpr).indent(2),
               pp_jaxpr(false_jaxpr).indent(2)])


pp_rules[cond_p] = pprint_cond