PyTrees#

什么是 PyTree?#

在 JAX 中,我们使用术语 pytree 来指代由类似容器的 Python 对象构建的树状结构。如果类在 pytree 注册表中,则被认为是类似容器的,默认情况下包括列表、元组和字典。也就是说

  1. 类型在 pytree 容器注册表中的任何对象都被认为是 pytree;

  2. 类型在 pytree 容器注册表中,并且包含 pytree 的任何对象都被认为是 pytree。

对于 pytree 容器注册表中的每个条目,都会注册一个类似容器的类型,该类型具有一对函数,用于指定如何将容器类型的实例转换为 (children, metadata) 对,以及如何将这种对转换回容器类型的实例。使用这些函数,JAX 可以将任何注册容器对象的树规范化为元组。

PyTree 示例

[1, "a", object()]  # 3 leaves

(1, (2, 3), ())  # 3 leaves

[1, {"k1": 2, "k2": (3, 4)}, 5]  # 5 leaves

可以扩展 JAX 以将其他容器类型视为 pytree;请参阅下面的扩展 PyTree

PyTree 和 JAX 函数#

许多 JAX 函数,例如 jax.lax.scan(),都在数组的 pytree 上运行。JAX 函数转换可以应用于接受数组的 pytree 作为输入并产生数组的 pytree 作为输出的函数。

将可选参数应用于 PyTree#

某些 JAX 函数转换会采用可选参数,这些参数指定应如何处理某些输入或输出值(例如,vmap()in_axesout_axes 参数)。这些参数也可以是 pytree,它们的结构必须与相应参数的 pytree 结构相对应。特别是,为了能够将这些参数 pytree 中的叶子与参数 pytree 中的值“匹配起来”,通常会将参数 pytree 约束为参数 pytree 的树前缀。

例如,如果我们向 vmap() 传递以下输入(请注意,函数的输入参数被视为元组)

(a1, {"k1": a2, "k2": a3})

我们可以使用以下 in_axes pytree 来指定仅映射 k2 参数(axis=0),其余参数不映射(axis=None

(None, {"k1": None, "k2": 0})

可选参数 pytree 结构必须与主输入 pytree 的结构匹配。但是,可选参数可以选择指定为“前缀” pytree,这意味着单个叶值可以应用于整个子 pytree。例如,如果我们有与上面相同的 vmap() 输入,但希望仅映射字典参数,则可以使用

(None, 0)  # equivalent to (None, {"k1": 0, "k2": 0})

或者,如果我们希望映射每个参数,我们可以简单地编写一个应用于整个参数元组 pytree 的单个叶值

0

这恰好是 vmap() 的默认 in_axes 值!

相同的逻辑适用于引用转换函数的特定输入或输出值的其他可选参数,例如 vmapout_axes

查看对象的 PyTree 定义#

要查看任意 object 的 pytree 定义以进行调试,可以使用

from jax.tree_util import tree_structure
print(tree_structure(object))

开发者信息#

这主要是 JAX 内部文档,最终用户不应该需要理解它才能使用 JAX,除非在 JAX 中注册新的用户定义容器类型。其中一些细节可能会发生变化。

内部 PyTree 处理#

JAX 将 pytree 展平为 api.py 边界(以及控制流原语)的叶列表。这使得下游 JAX 内部结构更简单:像 grad()jit()vmap() 这样的转换可以处理接受并返回各种不同 Python 容器的用户函数,而系统的所有其他部分都可以对只接受(多个)数组参数并始终返回扁平数组列表的函数进行操作。

当 JAX 展平 pytree 时,它将生成一个叶列表和一个 treedef 对象,该对象编码原始值的结构。然后,可以使用 treedef 在转换叶子后构造匹配的结构化值。PyTree 是树状的,而不是 DAG 状或图状的,因为我们处理它们时假设引用透明性,并且它们不能包含引用循环。

这是一个简单的例子

from jax.tree_util import tree_flatten, tree_unflatten
import jax.numpy as jnp

# The structured value to be transformed
value_structured = [1., (2., 3.)]

# The leaves in value_flat correspond to the `*` markers in value_tree
value_flat, value_tree = tree_flatten(value_structured)
print(f"{value_flat=}\n{value_tree=}")

# Transform the flat value list using an element-wise numeric transformer
transformed_flat = list(map(lambda v: v * 2., value_flat))
print(f"{transformed_flat=}")

# Reconstruct the structured output, using the original
transformed_structured = tree_unflatten(value_tree, transformed_flat)
print(f"{transformed_structured=}")
value_flat=[1.0, 2.0, 3.0]
value_tree=PyTreeDef([*, (*, *)])
transformed_flat=[2.0, 4.0, 6.0]
transformed_structured=[2.0, (4.0, 6.0)]

默认情况下,pytree 容器可以是列表、元组、字典、namedtuple、None、OrderedDict。其他类型的值,包括数值和 ndarray 值,都被视为叶子

from collections import namedtuple
Point = namedtuple('Point', ['x', 'y'])

example_containers = [
    (1., [2., 3.]),
    (1., {'b': 2., 'a': 3.}),
    1.,
    None,
    jnp.zeros(2),
    Point(1., 2.)
]
def show_example(structured):
  flat, tree = tree_flatten(structured)
  unflattened = tree_unflatten(tree, flat)
  print(f"{structured=}\n  {flat=}\n  {tree=}\n  {unflattened=}")

for structured in example_containers:
  show_example(structured)
structured=(1.0, [2.0, 3.0])
  flat=[1.0, 2.0, 3.0]
  tree=PyTreeDef((*, [*, *]))
  unflattened=(1.0, [2.0, 3.0])
structured=(1.0, {'b': 2.0, 'a': 3.0})
  flat=[1.0, 3.0, 2.0]
  tree=PyTreeDef((*, {'a': *, 'b': *}))
  unflattened=(1.0, {'a': 3.0, 'b': 2.0})
structured=1.0
  flat=[1.0]
  tree=PyTreeDef(*)
  unflattened=1.0
structured=None
  flat=[]
  tree=PyTreeDef(None)
  unflattened=None
structured=Array([0., 0.], dtype=float32)
  flat=[Array([0., 0.], dtype=float32)]
  tree=PyTreeDef(*)
  unflattened=Array([0., 0.], dtype=float32)
structured=Point(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(namedtuple[Point], [*, *]))
  unflattened=Point(x=1.0, y=2.0)

扩展 PyTree#

默认情况下,未被识别为内部 pytree 节点(即类似容器)的结构化值的任何部分都被视为叶子

class Special(object):
  def __init__(self, x, y):
    self.x = x
    self.y = y

  def __repr__(self):
    return "Special(x={}, y={})".format(self.x, self.y)


show_example(Special(1., 2.))
structured=Special(x=1.0, y=2.0)
  flat=[Special(x=1.0, y=2.0)]
  tree=PyTreeDef(*)
  unflattened=Special(x=1.0, y=2.0)

被视为内部 pytree 节点的 Python 类型集是可扩展的,通过类型的全局注册表,并且递归地遍历注册类型的值。要注册新类型,可以使用 register_pytree_node()

from jax.tree_util import register_pytree_node

class RegisteredSpecial(Special):
  def __repr__(self):
    return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)

def special_flatten(v):
  """Specifies a flattening recipe.

  Params:
    v: the value of registered type to flatten.
  Returns:
    a pair of an iterable with the children to be flattened recursively,
    and some opaque auxiliary data to pass back to the unflattening recipe.
    The auxiliary data is stored in the treedef for use during unflattening.
    The auxiliary data could be used, e.g., for dictionary keys.
  """
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)

def special_unflatten(aux_data, children):
  """Specifies an unflattening recipe.

  Params:
    aux_data: the opaque data that was specified during flattening of the
      current treedef.
    children: the unflattened children

  Returns:
    a re-constructed object of the registered type, using the specified
    children and auxiliary data.
  """
  return RegisteredSpecial(*children)

# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,    # tell JAX what are the children nodes
    special_unflatten   # tell JAX how to pack back into a RegisteredSpecial
)

show_example(RegisteredSpecial(1., 2.))
structured=RegisteredSpecial(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(RegisteredSpecial[None], [*, *]))
  unflattened=RegisteredSpecial(x=1.0, y=2.0)

或者,您可以在类上定义适当的 tree_flattentree_unflatten 方法,并使用 register_pytree_node_class() 修饰它

from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class RegisteredSpecial2(Special):
  def __repr__(self):
    return "RegisteredSpecial2(x={}, y={})".format(self.x, self.y)

  def tree_flatten(self):
    children = (self.x, self.y)
    aux_data = None
    return (children, aux_data)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)

show_example(RegisteredSpecial2(1., 2.))
structured=RegisteredSpecial2(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(RegisteredSpecial2[None], [*, *]))
  unflattened=RegisteredSpecial2(x=1.0, y=2.0)

在定义 unflattening 函数时,通常 children 应包含数据结构的所有动态元素(数组、动态标量和 pytree),而 aux_data 应包含将滚入 treedef 结构中的所有静态元素。JAX 有时需要比较 treedef 的相等性,或计算其哈希值以用于 JIT 缓存,因此必须小心确保扁平化配方中指定的辅助数据支持有意义的哈希和相等性比较。

用于操作 pytree 的整套函数都在 jax.tree_util 中。

自定义 PyTree 和初始化#

用户定义的 PyTree 对象的一个常见陷阱是,JAX 转换偶尔会使用意外的值初始化它们,因此在初始化时进行的任何输入验证都可能会失败。例如

class MyTree:
  def __init__(self, a):
    self.a = jnp.asarray(a)

register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
    lambda _, args: MyTree(*args))

tree = MyTree(jnp.arange(5.0))

jax.vmap(lambda x: x)(tree)      # Error because object() is passed to MyTree.
jax.jacobian(lambda x: x)(tree)  # Error because MyTree(...) is passed to MyTree

在第一种情况下,JAX 的内部结构使用 object() 值的数组来推断树的结构;在第二种情况下,将树映射到树的函数的雅可比矩阵定义为树的树。

因此,自定义 PyTree 类的 __init____new__ 方法通常应避免进行任何数组转换或其他输入验证,否则应预料到并处理这些特殊情况。例如

class MyTree:
  def __init__(self, a):
    if not (type(a) is object or a is None or isinstance(a, MyTree)):
      a = jnp.asarray(a)
    self.a = a

另一种可能性是构造您的 tree_unflatten 函数,使其避免调用 __init__;例如

def tree_unflatten(aux_data, children):
  del aux_data  # unused in this class
  obj = object.__new__(MyTree)
  obj.a = a
  return obj

如果您选择此路线,请确保您的 tree_unflatten 函数在代码更新时与 __init__ 保持同步。