How JAX primitives work#
[email protected], October 2019.
JAX implements certain transformations of Python functions, e.g., `jit`, `grad`, `vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable, which means that as the Python function executes, the only operations it applies to the data are either inspections of data attributes such as shape or type, or special operations called JAX primitives. In particular, a JAX-traceable function is sometimes invoked by JAX with abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`, which captures the type and shape of values, but not the concrete data values. JAX primitives know how to operate both on concrete data values and on JAX abstract values.

The JAX-transformed functions must themselves be JAX-traceable functions, to ensure that these transformations can be composed, e.g., `jit(jacfwd(grad(f)))`.
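As a quick illustration of abstract execution (a minimal sketch added here, not part of the original text; it relies only on the public `jax.eval_shape` API), a function can be evaluated on shape and dtype information alone, with no concrete data:

import jax
import jax.numpy as jnp

def sketch_fn(x):
  return jnp.sin(x) * 2.0

# Run sketch_fn abstractly: only shapes and dtypes flow through, no data.
print(jax.eval_shape(sketch_fn, jax.ShapeDtypeStruct((2, 2), jnp.float32)))
# ShapeDtypeStruct(shape=(2, 2), dtype=float32)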
There are pre-defined JAX primitives corresponding to most XLA operations, e.g., add, matmul, sin, cos, and indexing. JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs using JAX's implementation of numpy are JAX-traceable and therefore transformable. Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.
The set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives, one can define a new primitive that encapsulates the behavior of the function.
The goal of this document is to explain the interfaces that a JAX primitive must support in order to allow JAX to perform all its transformations.
Assume that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically as "multiply_add(x, y, z) = x * y + z". This function operates on 3 identically-shaped tensors of floating point values and performs the operation pointwise.
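To pin down the intended pointwise semantics, here is a plain NumPy reference (a sketch added for illustration; the helper name multiply_add_reference is ours, not part of JAX):

import numpy as np

def multiply_add_reference(x, y, z):
  # Pointwise x * y + z on three identically-shaped arrays.
  return x * y + z

x = np.ones((2, 2), dtype=np.float32)
print(multiply_add_reference(x, 2. * x, 3. * x))  # every entry is 5.0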
Using existing primitives#
The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other functions that are themselves written using JAX primitives, e.g., those defined in the `jax.lax` module.
from jax import lax
from jax._src import api

def multiply_add_lax(x, y, z):
  """Implementation of multiply-add using the jax.lax primitives."""
  return lax.add(lax.mul(x, y), z)

def square_add_lax(a, b):
  """A square-add function using the newly defined multiply-add."""
  return multiply_add_lax(a, a, b)

print("square_add_lax = ", square_add_lax(2., 10.))
# Differentiate w.r.t. the first argument
print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.))
square_add_lax = 14.0
grad(square_add_lax) = 4.0
To understand how JAX uses the primitives internally, we add some helpers for tracing function calls.
#@title Helper functions (execute this cell)
import functools
import traceback

_indentation = 0

def _trace(msg=None):
  """Print a message at current indentation."""
  if msg is not None:
    print("  " * _indentation + msg)

def _trace_indent(msg=None):
  """Print a message and then indent the rest."""
  global _indentation
  _trace(msg)
  _indentation = 1 + _indentation

def _trace_unindent(msg=None):
  """Unindent then print a message."""
  global _indentation
  _indentation = _indentation - 1
  _trace(msg)

def trace(name):
  """A decorator for functions to trace arguments and results."""
  def trace_func(func):  # pylint: disable=missing-docstring
    def pp(v):
      """Print certain values more succinctly"""
      vtype = str(type(v))
      if "jax._src.xla_bridge._JaxComputationBuilder" in vtype:
        return "<JaxComputationBuilder>"
      elif "jaxlib.xla_extension.XlaOp" in vtype:
        return "<XlaOp at 0x{:x}>".format(id(v))
      elif ("partial_eval.JaxprTracer" in vtype or
            "batching.BatchTracer" in vtype or
            "ad.JVPTracer" in vtype):
        return "Traced<{}>".format(v.aval)
      elif isinstance(v, tuple):
        return "({})".format(pp_values(v))
      else:
        return str(v)

    def pp_values(args):
      return ", ".join([pp(arg) for arg in args])

    @functools.wraps(func)
    def func_wrapper(*args):
      _trace_indent("call {}({})".format(name, pp_values(args)))
      res = func(*args)
      _trace_unindent("|<- {} = {}".format(name, pp(res)))
      return res

    return func_wrapper

  return trace_func

class expectNotImplementedError(object):
  """Context manager to check for NotImplementedError."""
  def __enter__(self): pass
  def __exit__(self, type, value, tb):
    global _indentation
    _indentation = 0
    if type is NotImplementedError:
      print("\nFound expected exception:")
      traceback.print_exc(limit=3)
      return True
    elif type is None:  # No exception
      assert False, "Expected NotImplementedError"
    else:
      return False
Instead of using the `jax.lax` primitives directly, we can use other functions that are already written in terms of those primitives, such as those in `jax.numpy`.
import jax.numpy as jnp
import numpy as np

@trace("multiply_add_numpy")
def multiply_add_numpy(x, y, z):
  return jnp.add(jnp.multiply(x, y), z)

@trace("square_add_numpy")
def square_add_numpy(a, b):
  return multiply_add_numpy(a, a, b)

print("\nNormal evaluation:")
print("square_add_numpy = ", square_add_numpy(2., 10.))
print("\nGradient evaluation:")
print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.))
Normal evaluation:
call square_add_numpy(2.0, 10.0)
  call multiply_add_numpy(2.0, 2.0, 10.0)
  |<- multiply_add_numpy = 14.0
|<- square_add_numpy = 14.0
square_add_numpy = 14.0

Gradient evaluation:
call square_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  call multiply_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  |<- multiply_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
|<- square_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
grad(square_add_numpy) = 4.0
Notice that in the process of computing `grad`, JAX invoked `square_add_numpy` and `multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further below in this colab). It is important to remember that a JAX-traceable function must be able to operate not only on concrete arguments but also on the special abstract arguments that JAX may use to abstract the function execution.

The JAX traceability property is satisfied as long as the function is written in terms of JAX primitives.
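One way to see which primitives a JAX-traceable function uses is to print its jaxpr (a sketch added here for illustration; it relies only on the public `jax.make_jaxpr` API):

from jax import make_jaxpr

# Tracing square_add_numpy with abstract arguments records the primitives
# it applies; the jaxpr should show a mul followed by an add.
print(make_jaxpr(square_add_numpy)(2., 10.))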
Defining new JAX primitives#
The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, in order to demonstrate how JAX primitives work, let us pretend that we want to add a new primitive to JAX for the multiply-add functionality.
from jax import core

multiply_add_p = core.Primitive("multiply_add")  # Create the primitive

@trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
  """The JAX-traceable way to use the JAX primitive.

  Note that the traced arguments must be passed as positional arguments
  to `bind`.
  """
  return multiply_add_p.bind(x, y, z)

@trace("square_add_prim")
def square_add_prim(a, b):
  """A square-add function implemented using the new JAX-primitive."""
  return multiply_add_prim(a, a, b)
If we try to call the newly defined functions we get an error, because we have not yet told JAX anything about the semantics of the new primitive.
with expectNotImplementedError():
  square_add_prim(2., 10.)
call square_add_prim(2.0, 10.0)
  call multiply_add_prim(2.0, 2.0, 10.0)

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1426/2844449444.py", line 2, in <module>
    square_add_prim(2., 10.)
  File "/tmp/ipykernel_1426/2656036843.py", line 48, in func_wrapper
    res = func(*args)
  File "/tmp/ipykernel_1426/1912233066.py", line 16, in square_add_prim
    return multiply_add_prim(a, a, b)
NotImplementedError: Evaluation rule for 'multiply_add' not implemented
Primal evaluation rules#
@trace("multiply_add_impl")
def multiply_add_impl(x, y, z):
"""Concrete implementation of the primitive.
This function does not need to be JAX traceable.
Args:
x, y, z: the concrete arguments of the primitive. Will only be called with
concrete values.
Returns:
the concrete result of the primitive.
"""
# Note that we can use the original numpy, which is not JAX traceable
return np.add(np.multiply(x, y), z)
# Now we register the primal implementation with JAX
multiply_add_p.def_impl(multiply_add_impl)
<function __main__.multiply_add_impl(x, y, z)>
assert square_add_prim(2., 10.) == 14.
call square_add_prim(2.0, 10.0)
  call multiply_add_prim(2.0, 2.0, 10.0)
    call multiply_add_impl(2.0, 2.0, 10.0)
    |<- multiply_add_impl = 14.0
  |<- multiply_add_prim = 14.0
|<- square_add_prim = 14.0
JIT#
If we now try to use `jit`, we get a `NotImplementedError`.
with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1426/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 331, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
NotImplementedError: Abstract evaluation for 'multiply_add' not implemented
Abstract evaluation rules#
In order to JIT the function, and for other transformations as well, JAX first evaluates it abstractly, using only the shape and type of the arguments. This abstract evaluation serves multiple purposes:

- It obtains the sequence of JAX primitives that are used in the computation. This sequence will be compiled.
- It computes the shape and type of all vectors and operations used in the computation.

For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`. In the latter case, JAX uses the actual concrete value wrapped as an abstract value.
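As a small illustration (an added sketch; it constructs an abstract value directly via `jax.core`, which the cell below also uses):

from jax import core
import numpy as np

# An abstract value records only shape and dtype, not the data.
aval = core.ShapedArray((3,), np.dtype("float32"))
print(aval)  # float32[3]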
from jax import core

@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
  """Abstract evaluation of the primitive.

  This function does not need to be JAX traceable. It will be invoked with
  abstractions of the actual arguments.
  Args:
    xs, ys, zs: abstractions of the arguments.
  Result:
    a ShapedArray for the result of the primitive.
  """
  assert xs.shape == ys.shape
  assert xs.shape == zs.shape
  return core.ShapedArray(xs.shape, xs.dtype)

# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
<function __main__.multiply_add_abstract_eval(xs, ys, zs)>
If we try to JIT again, we see how the abstract evaluation proceeds, but we get another error, this time about a missing XLA compilation rule.
with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>

Found expected exception:
Traceback (most recent call last):
  File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmp/ipykernel_1426/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 331, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu
XLA compilation rules#
JAX compilation works by compiling each primitive into a graph of XLA operations.

This is the biggest hurdle to adding new functionality to JAX, because the set of XLA operations is limited, and JAX already has pre-defined primitives for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++.
from jax._src.lib.mlir.dialects import hlo

@trace("multiply_add_lowering")
def multiply_add_lowering(ctx, xc, yc, zc):
  """The compilation to XLA of the primitive.

  Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
  the results of the function.

  Does not need to be a JAX-traceable function.
  """
  return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]

# Now we register the lowering rule with JAX
# For GPU see the [Custom operations for GPUs](https://jax.ac.cn/en/latest/Custom_Operation_for_GPUs.html)
# TODO: TPU?
from jax.interpreters import mlir
mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
<function __main__.multiply_add_lowering(ctx, xc, yc, zc)>
Now we succeed to JIT. Notice below that JAX first evaluates the function abstractly, which triggers the `multiply_add_abstract_eval` function, and then compiles the set of primitives it has encountered, including `multiply_add`. At this point JAX invokes `multiply_add_lowering`.
assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f62bc56b970>]
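As a side check (an added sketch, not in the original notebook; it relies on the public `.lower()`/`.as_text()` API of jitted functions), we can print the StableHLO that our lowering rule emitted:

# Inspect the lowered module; it should contain a stablehlo multiply
# feeding a stablehlo add, produced by multiply_add_lowering.
print(api.jit(square_add_prim).lower(2., 10.).as_text())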
Below is another use of `jit`, where we compile only with respect to the first argument. Notice that the second argument of `square_add_prim` is concrete, which leads to the third argument of `multiply_add_abstract_eval` being `ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with both `ShapedArray` and `ConcreteArray`.
assert api.jit(lambda x, y: square_add_prim(x, y),
               static_argnums=1)(2., 10.) == 14.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+01> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f62bc33d4b0>]
Forward differentiation#
JAX implements forward differentiation in the form of a Jacobian-Vector Product (JVP) (see the JAX autodiff cookbook).

If we attempt now to compute the `jvp` function, we get an error, because we have not yet told JAX how to differentiate the `multiply_add` primitive.
# The second argument `(2., 10.)` are the argument values
# where we evaluate the Jacobian, and the third `(1., 1.)`
# are the values of the tangents for the arguments.
with expectNotImplementedError():
  api.jvp(square_add_prim, (2., 10.), (1., 1.))
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
  call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1426/800067577.py", line 5, in <module>
    api.jvp(square_add_prim, (2., 10.), (1., 1.))
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1681, in jvp
    return _jvp(lu.wrap_init(fun), primals, tangents, has_aux=has_aux)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1710, in _jvp
    out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
NotImplementedError: Differentiation rule for 'multiply_add' not implemented
from jax.interpreters import ad

@trace("multiply_add_value_and_jvp")
def multiply_add_value_and_jvp(arg_values, arg_tangents):
  """Evaluates the primal output and the tangents (Jacobian-vector product).

  Given values of the arguments and perturbation of the arguments (tangents),
  compute the output of the primitive and the perturbation of the output.

  This method must be JAX-traceable. JAX may invoke it with abstract values
  for the arguments and tangents.

  Args:
    arg_values: a tuple of arguments
    arg_tangents: a tuple with the tangents of the arguments. The tuple has
      the same length as the arg_values. Some of the tangents may also be the
      special value ad.Zero to specify a zero tangent.
  Returns:
    a pair of the primal output and the tangent.
  """
  x, y, z = arg_values
  xt, yt, zt = arg_tangents
  _trace("Primal evaluation:")
  # Now we have a JAX-traceable computation of the output.
  # Normally, we can use the ma primitive itself to compute the primal output.
  primal_out = multiply_add_prim(x, y, z)

  _trace("Tangent evaluation:")
  # We must use a JAX-traceable way to compute the tangent. It turns out that
  # the output tangent can be computed as (xt * y + x * yt + zt),
  # which we can implement in a JAX-traceable way using the same "multiply_add_prim" primitive.

  # We do need to deal specially with Zero. Here we just turn it into a
  # proper tensor of 0s (of the same shape as 'x').
  # An alternative would be to check for Zero and perform algebraic
  # simplification of the output tangent computation.
  def make_zero(tan):
    return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan

  output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
  return (primal_out, output_tangent)

# Register the forward differentiation rule with JAX
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp
# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
  call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, 1.0, 1.0)
        call multiply_add_impl(2.0, 1.0, 1.0)
        |<- multiply_add_impl = 3.0
      |<- multiply_add_prim = 3.0
      call multiply_add_prim(1.0, 2.0, 3.0)
        call multiply_add_impl(1.0, 2.0, 3.0)
        |<- multiply_add_impl = 5.0
      |<- multiply_add_prim = 5.0
    |<- multiply_add_value_and_jvp = (14.0, 5.0)
  |<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
TO EXPLAIN:

- Why is JAX using `ConcreteArray` in `square_add_prim`? There is no abstract evaluation going on here.
- Not sure how to explain that `multiply_add_prim` is invoked with `ConcreteValue`, yet we do not call `multiply_add_abstract_eval`.
- I think it would be useful to show the jaxpr here (see the sketch below).
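Following up on the note above, the jaxpr of the JVP computation can be displayed like this (an added sketch using the public `jax.make_jaxpr` API):

from jax import make_jaxpr

# The jaxpr of the jvp of square_add_prim should contain three
# multiply_add equations: one for the primal and two for the tangent.
print(make_jaxpr(lambda xs, ts: api.jvp(square_add_prim, xs, ts))((2., 10.), (1., 1.)))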
JIT of forward differentiation#
We can apply JIT to the forward differentiation function:
assert api.jit(lambda arg_values, arg_tangents:
                 api.jvp(square_add_prim, arg_values, arg_tangents))(
    (2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
    call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>), (Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>))
      Primal evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
      Tangent evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
    |<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f62bc3872b0>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 2), Value(<block argument> of type 'tensor<f32>' at index: 3))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f62bc390c30>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[])], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 2), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%3 = "stablehlo.add"(%2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f62d8021970>]
Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn evaluates abstractly both the primal and the tangent evaluation (a total of 3 invocations of the `ma` primitive). Then we compile the 3 occurrences of the primitive.
Reverse differentiation#
If we attempt now to use reverse differentiation, we see that JAX starts by using `multiply_add_value_and_jvp` to compute the forward differentiation for abstract values, but then runs into a `NotImplementedError`.

When computing the reverse differentiation, JAX first performs abstract evaluation of the forward differentiation code `multiply_add_value_and_jvp` to obtain a trace of primitives that compute the output tangent. Observe that JAX performs this abstract evaluation with concrete values for the differentiation point, and abstract values for the tangents. Observe also that JAX uses the special abstract tangent value `Zero` for the tangent corresponding to the 3rd argument of `ma`. This reflects the fact that we do not differentiate w.r.t. the 2nd argument of `square_add_prim`, which flows into the 3rd argument of `multiply_add_prim`.

Observe also that during the abstract evaluation of the tangent we pass the value 0.0 as the tangent for the 3rd argument. This is because of the use of the `make_zero` function in the definition of `multiply_add_value_and_jvp`.
# This is reverse differentiation w.r.t. the first argument of square_add_prim
with expectNotImplementedError():
  api.grad(square_add_prim)(2., 10.)
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
        call multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0, dtype=float32, weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
    |<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
  |<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>

Found expected exception:
Traceback (most recent call last):
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 283, in get_primitive_transpose
    return primitive_transposes[p]
KeyError: multiply_add

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmp/ipykernel_1426/339076514.py", line 3, in <module>
    api.grad(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 392, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented
The above error is because there is a missing piece that JAX needs in order to use the forward differentiation code to compute reverse differentiation.
Transposition#
As explained above, when computing reverse differentiation JAX obtains a trace of primitives that compute the tangent using forward differentiation. Then, **JAX interprets this trace abstractly backwards** and for each primitive it applies a **transposition** rule.
To understand what is going on, consider for now a simpler example of the function “f(x, y) = x * y + y”. Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the inputs `xt` and `yt`:
a = xt * 4.
b = 2. * yt
c = a + b
ft = c + yt
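We can sanity-check this hand-written tangent computation against `jax.jvp` (a sketch; `f` below is the example function, not our primitive):
# With xt = yt = 1., the hand evaluation above gives:
#   a = 1. * 4. = 4.;  b = 2. * 1. = 2.;  c = 6.;  ft = c + yt = 7.
f = lambda x, y: x * y + y
primal_out, ft = api.jvp(f, (2., 4.), (1., 1.))
print(primal_out, ft)  # 12.0 7.0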
Observe that by construction the tangent calculation is always linear in the input tangents. The only non-linear operator that may arise in the tangent calculation is multiplication, but then one of the operands is constant.
JAX will produce the reverse differentiation computation by processing the JVP computation backwards. For each operation in the tangent computation, it accumulates the cotangents of the variables used by the operation, using the cotangent of the result of the operation:
# Initialize cotangents of inputs and intermediate vars
xct = yct = act = bct = cct = 0.
# Initialize cotangent of the output
fct = 1.
# Process "ft = c + yt"
cct += fct
yct += fct
# Process "c = a + b"
act += cct
bct += cct
# Process "b = 2. * yt"
yct += 2. * bct
# Process "a = xt * 4."
xct += act * 4.
One can verify that this computation produces `xct = 4.` and `yct = 3.`, which are the partial derivatives of the function `f`.
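Indeed, reverse-mode AD agrees with the cotangents obtained by hand (again a sketch on the example function):
# The partial derivatives of f at (2., 4.) are df/dx = y = 4. and df/dy = x + 1 = 3.
f = lambda x, y: x * y + y
print(api.grad(f, argnums=(0, 1))(2., 4.))  # (4.0, 3.0)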
JAX knows, for each primitive that may appear in a JVP calculation, how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is:
p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)
Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned for the constant arguments.
In particular:
add_transpose(out_ct, _, _) = (out_ct, out_ct)
mult_transpose(out_ct, x, _) = (None, x * out_ct)
mult_transpose(out_ct, _, y) = (out_ct * y, None)
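This machinery is also exposed directly as `jax.linear_transpose`, which transposes a function that is linear in its arguments. Here is a small sketch, independent of our primitive, for the linear function `p(y, z) = y*cy + z*cz`:
from jax import linear_transpose

cy, cz = 3., 5.
p = lambda y, z: y * cy + z * cz  # linear in y and z
# The example primals below only serve to fix the shapes/dtypes of the arguments.
p_transpose = linear_transpose(p, 1., 1.)
print(p_transpose(1.0))  # (out_ct*cy, out_ct*cz) == (3.0, 5.0)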
@trace("multiply_add_transpose")
def multiply_add_transpose(ct, x, y, z):
"""Evaluates the transpose of a linear primitive.
This method is only used when computing the backward gradient following
value_and_jvp, and is only needed for primitives that are used in the JVP
calculation for some other primitive. We need transposition for multiply_add_prim,
because we have used multiply_add_prim in the computation of the output_tangent in
multiply_add_value_and_jvp.
In our case, multiply_add is not a linear primitive. However, it is used linearly
w.r.t. tangents in multiply_add_value_and_jvp:
output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))
Always one of the first two multiplicative arguments is a constant.
Args:
ct: the cotangent of the output of the primitive.
x, y, z: values of the arguments. The arguments that are used linearly
get an ad.UndefinedPrimal value. The other arguments get a constant
value.
Returns:
a tuple with the cotangent of the inputs, with the value None
corresponding to the constant arguments.
"""
if not ad.is_undefined_primal(x):
# This use of multiply_add is with a constant "x"
assert ad.is_undefined_primal(y)
ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))
res = None, ct_y, ct
else:
# This use of multiply_add is with a constant "y"
assert ad.is_undefined_primal(x)
ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))
res = ct_x, None, ct
return res
ad.primitive_transposes[multiply_add_p] = multiply_add_transpose
Now we can complete the run of `grad`:
assert api.grad(square_add_prim)(2., 10.) == 4.
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
Primal evaluation:
call multiply_add_prim(2.0, 2.0, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
|<- multiply_add_impl = 14.0
|<- multiply_add_prim = 14.0
Tangent evaluation:
call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
call multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0, dtype=float32, weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[]))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
|<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
call multiply_add_transpose(1.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 2.0, UndefinedPrimal(ShapedArray(float32[])))
call multiply_add_prim(1.0, 2.0, 0.0)
call multiply_add_impl(1.0, 2.0, 0.0)
|<- multiply_add_impl = 2.0
|<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (2.0, None, 1.0)
call multiply_add_transpose(1.0, 2.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 0.0)
call multiply_add_prim(2.0, 1.0, 0.0)
call multiply_add_impl(2.0, 1.0, 0.0)
|<- multiply_add_impl = 2.0
|<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (None, 2.0, 1.0)
Notice the two calls to `multiply_add_transpose`. They correspond to the two uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0.
JIT of reverse differentiation#
Notice that the abstract evaluation of `multiply_add_value_and_jvp` uses only abstract values, while in the absence of JIT we used `ConcreteArray`:
assert api.jit(api.grad(square_add_prim))(2., 10.) == 4.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
Primal evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
Tangent evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>)
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[])))
call multiply_add_prim(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_transpose = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_transpose = (None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f62bc399800>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f62bc3a78f0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f62bc3a53b0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f62bc3a7bf0>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f62bcf92a40>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None, mesh_shape=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f62bc347700>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x556b2b436bf0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1426/1912233066.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1426/2656036843.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1426/454004196.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1426/2656036843.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1426/1912233066.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1426/2656036843.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1426/1912233066.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1426/2656036843.py":48:0) at callsite("<module>"("/tmp/ipykernel_1426/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f62bc542970, file "/tmp/ipykernel_1426/1912233066.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1426/1912233066.py":11:0)), (<code object func_wrapper at 0x7f62bce8e550, file "/tmp/ipykernel_1426/2656036843.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1426/2656036843.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7f62bc57ae40, file "/tmp/ipykernel_1426/454004196.py", line 4>, 88): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1426/454004196.py":41:0)), (<code object square_add_prim at 0x7f62bc543520, file "/tmp/ipykernel_1426/1912233066.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1426/1912233066.py":16:0)), (<code object <module> at 0x7f62bc57bc00, file "/tmp/ipykernel_1426/3085343041.py", line 1>, 18): loc("<module>"("/tmp/ipykernel_1426/3085343041.py":1:0)), (<code object run_code at 0x7f62e37aa080, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))}, canonical_name_cache={'/tmp/ipykernel_1426/1912233066.py': '/tmp/ipykernel_1426/1912233066.py', '/tmp/ipykernel_1426/2656036843.py': '/tmp/ipykernel_1426/2656036843.py', '/tmp/ipykernel_1426/454004196.py': '/tmp/ipykernel_1426/454004196.py', '/tmp/ipykernel_1426/3085343041.py': '/tmp/ipykernel_1426/3085343041.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, 
is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1426/1912233066.py': True, '/tmp/ipykernel_1426/2656036843.py': True, '/tmp/ipykernel_1426/454004196.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_1426/3085343041.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f62d9d5ad40>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None,threefry_partitionable=False),xla_metadata={}, platforms=None), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%1 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f62bc59a430>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f62bc399800>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f62bc3a78f0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f62bc3a53b0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f62bc3a7bf0>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f62bcf92a40>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None, mesh_shape=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f62bc347700>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x556b2b436bf0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1426/1912233066.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1426/2656036843.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1426/454004196.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1426/2656036843.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1426/1912233066.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1426/2656036843.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1426/1912233066.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1426/2656036843.py":48:0) at callsite("<module>"("/tmp/ipykernel_1426/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x556b2aff4e60>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1426/1912233066.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1426/2656036843.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1426/454004196.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1426/2656036843.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1426/1912233066.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1426/2656036843.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1426/1912233066.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1426/2656036843.py":48:0) at callsite("<module>"("/tmp/ipykernel_1426/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f62bc542970, file "/tmp/ipykernel_1426/1912233066.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1426/1912233066.py":11:0)), (<code object func_wrapper at 0x7f62bce8e550, file "/tmp/ipykernel_1426/2656036843.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1426/2656036843.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7f62bc57ae40, file "/tmp/ipykernel_1426/454004196.py", line 4>, 88): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1426/454004196.py":41:0)), (<code object square_add_prim at 0x7f62bc543520, file "/tmp/ipykernel_1426/1912233066.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1426/1912233066.py":16:0)), (<code object <module> at 0x7f62bc57bc00, file "/tmp/ipykernel_1426/3085343041.py", line 1>, 18): loc("<module>"("/tmp/ipykernel_1426/3085343041.py":1:0)), (<code object run_code at 0x7f62e37aa080, file 
"/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)), (<code object multiply_add_value_and_jvp at 0x7f62bc57ae40, file "/tmp/ipykernel_1426/454004196.py", line 4>, 86): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1426/454004196.py":41:0))}, canonical_name_cache={'/tmp/ipykernel_1426/1912233066.py': '/tmp/ipykernel_1426/1912233066.py', '/tmp/ipykernel_1426/2656036843.py': '/tmp/ipykernel_1426/2656036843.py', '/tmp/ipykernel_1426/454004196.py': '/tmp/ipykernel_1426/454004196.py', '/tmp/ipykernel_1426/3085343041.py': '/tmp/ipykernel_1426/3085343041.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1426/1912233066.py': True, '/tmp/ipykernel_1426/2656036843.py': True, '/tmp/ipykernel_1426/454004196.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_1426/3085343041.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f62d9d5b550>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None,threefry_partitionable=False),xla_metadata={}, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%4 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>), Value(%5 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f62bc3bccb0>]
Batching#
The batching transformation turns a point-wise computation into a computation on vectors. If we try it now, we get a `NotImplementedError`:
# The arguments are two vectors instead of two scalars
with expectNotImplementedError():
  api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
                                                   np.array([10., 20.]))
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
Found expected exception:
Traceback (most recent call last):
File "/tmp/ipykernel_1426/2641678767.py", line 3, in <module>
api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 986, in vmap_f
out_flat = batching.batch(
NotImplementedError: Batching rule for 'multiply_add' not implemented
We need to instruct JAX how to evaluate the batched version of the primitive. In this particular case, `multiply_add_prim` already operates pointwise on tensors of any dimension, so the batched version can use the very same `multiply_add_prim` implementation.
from jax.interpreters import batching

@trace("multiply_add_batch")
def multiply_add_batch(vector_arg_values, batch_axes):
  """Computes the batched version of the primitive.

  This must be a JAX-traceable function.

  Since the multiply_add primitive already operates pointwise on arbitrary
  dimension tensors, to batch it we can use the primitive itself. This works as
  long as both the inputs have the same dimensions and are batched along the
  same axes. The result is batched along the axis that the inputs are batched.

  Args:
    vector_arg_values: a tuple of three arguments, each being a tensor of
      matching shape.
    batch_axes: the axes that are being batched. See vmap documentation.
  Returns:
    a tuple of the result, and the result axis that was batched.
  """
  assert batch_axes[0] == batch_axes[1]
  assert batch_axes[0] == batch_axes[2]
  _trace("Using multiply_add to compute the batch:")
  res = multiply_add_prim(*vector_arg_values)
  return res, batch_axes[0]

batching.primitive_batchers[multiply_add_p] = multiply_add_batch
assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(
                       np.array([2., 3.]), np.array([10., 20.])),
                   [14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_batch(([2. 3.], [2. 3.], [10. 20.]), (0, 0, 0))
Using multiply_add to compute the batch:
call multiply_add_prim([2. 3.], [2. 3.], [10. 20.])
call multiply_add_impl([2. 3.], [2. 3.], [10. 20.])
|<- multiply_add_impl = [14. 29.]
|<- multiply_add_prim = [14. 29.]
|<- multiply_add_batch = ([14. 29.], 0)
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
JIT of batching#
assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))(
                       np.array([2., 3.]), np.array([10., 20.])),
                   [14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_batch((Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>), (0, 0, 0))
Using multiply_add to compute the batch:
call multiply_add_prim(Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2]))
|<- multiply_add_abstract_eval = ShapedArray(float32[2])
|<- multiply_add_prim = Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_batch = (Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, 0)
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f62bc39a1b0>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f62bc390230>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f62bc3918f0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f62bc3918b0>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f62bcf92a40>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None, mesh_shape=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f62bc346080>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x556b29b0c080>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1426/1912233066.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1426/2656036843.py":48:0) at callsite("multiply_add_batch"("/tmp/ipykernel_1426/2356588168.py":25:0) at callsite("func_wrapper"("/tmp/ipykernel_1426/2656036843.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1426/1912233066.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1426/2656036843.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1426/1912233066.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1426/2656036843.py":48:0) at callsite("<module>"("/tmp/ipykernel_1426/1392464762.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f62bc542970, file "/tmp/ipykernel_1426/1912233066.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1426/1912233066.py":11:0)), (<code object func_wrapper at 0x7f62bce8e550, file "/tmp/ipykernel_1426/2656036843.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1426/2656036843.py":48:0)), (<code object multiply_add_batch at 0x7f62bc342600, file "/tmp/ipykernel_1426/2356588168.py", line 4>, 52): loc("multiply_add_batch"("/tmp/ipykernel_1426/2356588168.py":25:0)), (<code object square_add_prim at 0x7f62bc543520, file "/tmp/ipykernel_1426/1912233066.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1426/1912233066.py":16:0)), (<code object <module> at 0x7f62bc57b940, file "/tmp/ipykernel_1426/1392464762.py", line 1>, 48): loc("<module>"("/tmp/ipykernel_1426/1392464762.py":1:0)), (<code object run_code at 0x7f62e37aa080, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))}, canonical_name_cache={'/tmp/ipykernel_1426/1912233066.py': '/tmp/ipykernel_1426/1912233066.py', '/tmp/ipykernel_1426/2656036843.py': '/tmp/ipykernel_1426/2656036843.py', '/tmp/ipykernel_1426/2356588168.py': '/tmp/ipykernel_1426/2356588168.py', '/tmp/ipykernel_1426/1392464762.py': '/tmp/ipykernel_1426/1392464762.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, 
is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1426/1912233066.py': True, '/tmp/ipykernel_1426/2656036843.py': True, '/tmp/ipykernel_1426/2356588168.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/batching.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_1426/1392464762.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='vmap'))), primitive=multiply_add, avals_in=[ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2])], avals_out=[ShapedArray(float32[2])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f62bc3cc310>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None,threefry_partitionable=False),xla_metadata={}, platforms=None), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f62bc3bcef0>]