使用 PyTree#

JAX 内置支持类似于数组字典(dicts)、列表的列表的字典或其他嵌套结构的对象 —— 在 JAX 中,这些被称为 PyTree。本节将解释如何使用它们,提供有用的代码示例,并指出常见的“陷阱”和模式。

什么是 PyTree?#

PyTree 是一个由类似容器的 Python 对象构建的类似容器的结构 —— “叶” PyTree 和/或更多 PyTree。PyTree 可以包含列表、元组和字典。叶是任何非 PyTree 的东西,例如数组,但单个叶也是一个 PyTree。

在机器学习(ML)的上下文中,PyTree 可以包含:

  • 模型参数

  • 数据集条目

  • 强化学习智能体观察

当处理数据集时,你经常会遇到 PyTree(例如列表的列表的字典)。

以下是一个简单的 PyTree 示例。在 JAX 中,你可以使用 jax.tree.leaves(),从树中提取扁平化的叶,如下所示

import jax
import jax.numpy as jnp

example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

# Print how many leaves the pytrees have.
for pytree in example_trees:
  # This `jax.tree.leaves()` method extracts the flattened leaves from the pytrees.
  leaves = jax.tree.leaves(pytree)
  print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
[1, 'a', <object object at 0x7eff2b79bd50>]   has 3 leaves: [1, 'a', <object object at 0x7eff2b79bd50>]
(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32)                 has 1 leaves: [Array([1, 2, 3], dtype=int32)]

任何由类似容器的 Python 对象构建的类树结构都可以被视为 JAX 中的 PyTree。如果类在 PyTree 注册表中,则认为类是类似容器的,默认情况下包括列表、元组和字典。类型不在 PyTree 容器注册表中的任何对象都将被视为树中的叶节点。

PyTree 注册表可以扩展为包括用户定义的容器类,通过注册具有指定如何展平树的函数来实现;请参阅下面的 自定义 PyTree 节点

常用的 PyTree 函数#

JAX 提供了许多实用程序来对 PyTree 进行操作。这些可以在 jax.tree_util 子包中找到;为方便起见,其中许多在 jax.tree 模块中具有别名。

常用函数:jax.tree.map#

最常用的 PyTree 函数是 jax.tree.map()。它的工作方式类似于 Python 的原生 map,但可以透明地在整个 PyTree 上操作。

这是一个例子

list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]

jax.tree.map(lambda x: x*2, list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

jax.tree.map() 还允许将 N 元 函数映射到多个参数上。例如

another_list_of_lists = list_of_lists
jax.tree.map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

当将多个参数与 jax.tree.map() 一起使用时,输入的结构必须完全匹配。也就是说,列表必须具有相同数量的元素,字典必须具有相同的键,等等。

使用 ML 模型参数的 jax.tree.map 示例#

此示例演示了在训练简单的 多层感知器(MLP)时,PyTree 操作如何有用。

首先定义初始模型参数

import numpy as np

def init_mlp_params(layer_widths):
  params = []
  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
    params.append(
        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
             biases=np.ones(shape=(n_out,))
            )
    )
  return params

params = init_mlp_params([1, 128, 128, 1])

使用 jax.tree.map() 检查初始参数的形状

jax.tree.map(lambda x: x.shape, params)
[{'biases': (128,), 'weights': (1, 128)},
 {'biases': (128,), 'weights': (128, 128)},
 {'biases': (1,), 'weights': (128, 1)}]

接下来,定义训练 MLP 模型的功能

# Define the forward pass.
def forward(params, x):
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

# Define the loss function.
def loss_fn(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

# Set the learning rate.
LEARNING_RATE = 0.0001

# Using the stochastic gradient descent, define the parameter update function.
# Apply `@jax.jit` for JIT compilation (speed).
@jax.jit
def update(params, x, y):
  # Calculate the gradients with `jax.grad`.
  grads = jax.grad(loss_fn)(params, x, y)
  # Note that `grads` is a pytree with the same structure as `params`.
  # `jax.grad` is one of many JAX functions that has
  # built-in support for pytrees.
  # This is useful - you can apply the SGD update using JAX pytree utilities.
  return jax.tree.map(
      lambda p, g: p - LEARNING_RATE * g, params, grads
  )

自定义 PyTree 节点#

本节解释了如何在 JAX 中通过使用 jax.tree_util.register_pytree_node()jax.tree.map() 来扩展 PyTree(PyTree 节点)中被认为是内部节点的 Python 类型集。

为什么你需要这个?在前面的示例中,PyTree 显示为列表、元组和字典,而其他所有内容都作为 PyTree 的叶。这是因为,如果你定义自己的容器类,它将被视为 PyTree 叶,除非你在 JAX 中注册它。即使你的容器类内部有树,情况也是如此。例如

class Special(object):
  def __init__(self, x, y):
    self.x = x
    self.y = y

jax.tree.leaves([
    Special(0, 1),
    Special(2, 4),
])
[<__main__.Special at 0x7eff480ed390>, <__main__.Special at 0x7eff480efca0>]

因此,如果你尝试使用 jax.tree.map(),期望叶是容器内的元素,你将收到错误

jax.tree.map(lambda x: x + 1,
  [
    Special(0, 1),
    Special(2, 4)
  ])
TypeError: unsupported operand type(s) for +: 'Special' and 'int'

作为解决方案,JAX 允许通过类型的全局注册表扩展被视为内部 PyTree 节点的类型集。此外,已注册类型的值会被递归地遍历。

首先,使用 jax.tree_util.register_pytree_node() 注册新类型

from jax.tree_util import register_pytree_node

class RegisteredSpecial(Special):
  def __repr__(self):
    return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)

def special_flatten(v):
  """Specifies a flattening recipe.

  Params:
    v: The value of the registered type to flatten.
  Returns:
    A pair of an iterable with the children to be flattened recursively,
    and some opaque auxiliary data to pass back to the unflattening recipe.
    The auxiliary data is stored in the treedef for use during unflattening.
    The auxiliary data could be used, for example, for dictionary keys.
  """
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)

def special_unflatten(aux_data, children):
  """Specifies an unflattening recipe.

  Params:
    aux_data: The opaque data that was specified during flattening of the
      current tree definition.
    children: The unflattened children

  Returns:
    A reconstructed object of the registered type, using the specified
    children and auxiliary data.
  """
  return RegisteredSpecial(*children)

# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,    # Instruct JAX what are the children nodes.
    special_unflatten   # Instruct JAX how to pack back into a `RegisteredSpecial`.
)

现在你可以遍历特殊的容器结构

jax.tree.map(lambda x: x + 1,
  [
   RegisteredSpecial(0, 1),
   RegisteredSpecial(2, 4),
  ])
[RegisteredSpecial(x=1, y=2), RegisteredSpecial(x=3, y=5)]

现代 Python 配备了有用的工具,可以更轻松地定义容器。一些可以直接与 JAX 一起使用,但另一些则需要更多注意。

例如,Python 的 NamedTuple 子类不需要注册即可被视为 PyTree 节点类型

from typing import NamedTuple, Any

class MyOtherContainer(NamedTuple):
  name: str
  a: Any
  b: Any
  c: Any

# NamedTuple subclasses are handled as pytree nodes, so
# this will work out-of-the-box.
jax.tree.leaves([
    MyOtherContainer('Alice', 1, 2, 3),
    MyOtherContainer('Bob', 4, 5, 6)
])
['Alice', 1, 2, 3, 'Bob', 4, 5, 6]

请注意,name 字段现在显示为叶,因为所有元组元素都是子元素。这就是当你不必以困难的方式注册类时发生的情况。

NamedTuple 子类不同,使用 @dataclass 修饰的类不会自动成为 PyTree。但是,可以使用 jax.tree_util.register_dataclass() 修饰器将其注册为 PyTree

from dataclasses import dataclass
import functools

@functools.partial(jax.tree_util.register_dataclass,
                   data_fields=['a', 'b', 'c'],
                   meta_fields=['name'])
@dataclass
class MyDataclassContainer(object):
  name: str
  a: Any
  b: Any
  c: Any

# MyDataclassContainer is now a pytree node.
jax.tree.leaves([
  MyDataclassContainer('apple', 5.3, 1.2, jnp.zeros([4])),
  MyDataclassContainer('banana', np.array([3, 4]), -1., 0.)
])
[5.3, 1.2, Array([0., 0., 0., 0.], dtype=float32), array([3, 4]), -1.0, 0.0]

请注意,name 字段不显示为叶。这是因为我们将其包含在 jax.tree_util.register_dataclass()meta_fields 参数中,表明它应该被视为元数据/辅助数据,就像上面 RegisteredSpecial 中的 aux_data 一样。现在,MyDataclassContainer 的实例可以传递到 JIT 函数中,并且 name 将被视为静态(有关静态参数的更多信息,请参阅 将参数标记为静态

@jax.jit
def f(x: MyDataclassContainer | MyOtherContainer):
  return x.a + x.b

# Works fine! `mdc.name` is static.
mdc = MyDataclassContainer('mdc', 1, 2, 3)
y = f(mdc)

MyOtherContainerNamedTuple 子类)进行比较。由于 name 字段是 PyTree 叶,因此 JIT 希望它可以转换为 jax.Array,并且以下操作会引发错误

moc = MyOtherContainer('moc', 1, 2, 3)
y = f(moc)
TypeError: Error interpreting argument to <function f at 0x7eff0cdcc040> as an abstract array. The problematic value is of type <class 'str'> and was passed to the function at path x.name.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.

PyTree 和 JAX 转换#

许多 JAX 函数,例如 jax.lax.scan(),对数组的 PyTree 进行操作。此外,所有 JAX 函数转换都可以应用于接受数组的 PyTree 作为输入并产生数组的 PyTree 作为输出的函数。

一些 JAX 函数转换采用可选参数,这些参数指定应如何处理某些输入或输出值(例如 jax.vmap()in_axesout_axes 参数)。这些参数也可以是 PyTree,并且它们的结构必须与相应参数的 PyTree 结构相对应。特别是,为了能够将这些参数 PyTree 中的叶与参数 PyTree 中的值“匹配起来”,参数 PyTree 通常被限制为参数 PyTree 的树前缀。

例如,如果你将以下输入传递给 jax.vmap()(请注意,函数的输入参数被视为元组)

vmap(f, in_axes=(a1, {"k1": a2, "k2": a3}))

那么你可以使用以下 in_axes PyTree 来指定仅映射 k2 参数(axis=0),其余不映射(axis=None

vmap(f, in_axes=(None, {"k1": None, "k2": 0}))

可选参数 PyTree 结构必须与主输入 PyTree 的结构匹配。但是,可选参数可以选择指定为“前缀” PyTree,这意味着可以将单个叶值应用于整个子 PyTree。

例如,如果你具有与上面相同的 jax.vmap() 输入,但只想映射字典参数,则可以使用

vmap(f, in_axes=(None, 0))  # equivalent to (None, {"k1": 0, "k2": 0})

或者,如果你希望映射每个参数,则可以编写应用于整个参数元组 PyTree 的单个叶值

vmap(f, in_axes=0)  # equivalent to (0, {"k1": 0, "k2": 0})

这正好是 jax.vmap() 的默认 in_axes 值。

相同的逻辑适用于引用转换函数的特定输入或输出值的其他可选参数,例如 jax.vmap() 中的 out_axes

显式键路径#

在pytree中,每个叶子节点都有一个键路径。叶子节点的键路径是一个list,其中列表的长度等于叶子节点在pytree中的深度。每个都是一个可哈希对象,它表示对应pytree节点类型的索引。键的类型取决于pytree节点类型;例如,dict的键类型与tuple的键类型不同。

对于内置的pytree节点类型,任何pytree节点实例的键集合都是唯一的。对于包含具有此属性的节点的pytree,每个叶子节点的键路径都是唯一的。

JAX 提供了以下 jax.tree_util.* 方法来处理键路径

例如,一个用例是打印与特定叶子节点值相关的调试信息

import collections

ATuple = collections.namedtuple("ATuple", ('name'))

tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)

for key_path, value in flattened:
  print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')
Value of tree[0]: 1
Value of tree[1]['k1']: 2
Value of tree[1]['k2'][0]: 3
Value of tree[1]['k2'][1]: 4
Value of tree[2].name: foo

为了表示键路径,JAX 为内置的 pytree 节点类型提供了一些默认的键类型,即:

  • SequenceKey(idx: int): 用于列表和元组。

  • DictKey(key: Hashable): 用于字典。

  • GetAttrKey(name: str): 用于 namedtuple,并且最好用于自定义的 pytree 节点(在下一节中详细介绍)

您可以为自定义节点自由定义自己的键类型。只要它们的 __str__() 方法也被易于阅读的表达式覆盖,它们就可以与 jax.tree_util.keystr() 一起使用。

for key_path, _ in flattened:
  print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')
Key path of tree[0]: (SequenceKey(idx=0),)
Key path of tree[1]['k1']: (SequenceKey(idx=1), DictKey(key='k1'))
Key path of tree[1]['k2'][0]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=0))
Key path of tree[1]['k2'][1]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=1))
Key path of tree[2].name: (SequenceKey(idx=2), GetAttrKey(name='name'))

常见的 pytree 陷阱#

本节涵盖了使用 JAX pytree 时遇到的一些最常见的问题(“陷阱”)。

将 pytree 节点误认为叶子节点#

需要注意的一个常见陷阱是意外引入了树节点而不是叶子节点

a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

# Try to make another pytree with ones instead of zeros.
shapes = jax.tree.map(lambda x: x.shape, a_tree)
jax.tree.map(jnp.ones, shapes)
[(Array([1., 1.], dtype=float32), Array([1., 1., 1.], dtype=float32)),
 (Array([1., 1., 1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))]

这里发生的情况是,数组的 shape 是一个元组,它是一个 pytree 节点,其元素是叶子节点。因此,在映射中,不是在例如 (2, 3) 上调用 jnp.ones,而是在 23 上调用。

解决方案取决于具体情况,但有两种广泛适用的选择:

  • 重写代码以避免中间的 jax.tree.map()

  • 将元组转换为 NumPy 数组 (np.array) 或 JAX NumPy 数组 (jnp.array),这使得整个序列成为一个叶子节点。

jax.tree_utilNone 的处理#

jax.tree_util 函数将 None 视为缺少 pytree 节点,而不是叶子节点

jax.tree.leaves([None, None, None])
[]

要将 None 视为叶子节点,可以使用 is_leaf 参数

jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None)
[None, None, None]

自定义 pytree 和使用意外值进行初始化#

用户定义的 pytree 对象的另一个常见陷阱是,JAX 转换有时会使用意外的值来初始化它们,因此在初始化时完成的任何输入验证都可能失败。例如:

class MyTree:
  def __init__(self, a):
    self.a = jnp.asarray(a)

register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
    lambda _, args: MyTree(*args))

tree = MyTree(jnp.arange(5.0))

jax.vmap(lambda x: x)(tree)      # Error because object() is passed to `MyTree`.
TypeError: Value '<object object at 0x7eff10ae8640>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
jax.jacobian(lambda x: x)(tree)  # Error because MyTree(...) is passed to `MyTree`.
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:5819: FutureWarning: None encountered in jnp.array(); this is currently treated as NaN. In the future this will result in an error.
  return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
TypeError: Value '<object object at 0x7eff10ae8b30>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
  • 在第一个使用 jax.vmap(...)(tree) 的情况下,JAX 的内部使用 object() 值的数组来推断树的结构

  • 在第二个使用 jax.jacobian(...)(tree) 的情况下,将树映射到树的函数的雅可比矩阵定义为树的树。

潜在的解决方案 1

  • 自定义 pytree 类的 __init____new__ 方法通常应避免进行任何数组转换或其他输入验证,或者预测并处理这些特殊情况。例如

class MyTree:
  def __init__(self, a):
    if not (type(a) is object or a is None or isinstance(a, MyTree)):
      a = jnp.asarray(a)
    self.a = a

潜在的解决方案 2

  • 构建自定义的 tree_unflatten 函数,使其避免调用 __init__。如果您选择此方法,请确保您的 tree_unflatten 函数在代码更新时与 __init__ 保持同步。示例

def tree_unflatten(aux_data, children):
  del aux_data  # Unused in this class.
  obj = object.__new__(MyTree)
  obj.a = a
  return obj

常见的 pytree 模式#

本节介绍 JAX pytree 的一些最常见的模式。

使用 jax.tree.mapjax.tree.transpose 转置 pytree#

要转置 pytree (将树的列表转换为列表的树),JAX 有两个函数:{func} jax.tree.map (更基础) 和 jax.tree.transpose() (更灵活、复杂且冗长)。

选项 1: 使用 jax.tree.map()。这是一个示例

def tree_transpose(list_of_trees):
  """
  Converts a list of trees of identical structure into a single tree of lists.
  """
  return jax.tree.map(lambda *xs: list(xs), *list_of_trees)

# Convert a dataset from row-major to column-major.
episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
tree_transpose(episode_steps)
{'obs': [3, 4], 't': [1, 2]}

选项 2: 对于更复杂的转置,请使用 jax.tree.transpose(),它更冗长,但允许您指定内部和外部 pytree 的结构,以获得更大的灵活性。例如

jax.tree.transpose(
  outer_treedef = jax.tree.structure([0 for e in episode_steps]),
  inner_treedef = jax.tree.structure(episode_steps[0]),
  pytree_to_transpose = episode_steps
)
{'obs': [3, 4], 't': [1, 2]}