使用 PyTrees#
JAX 内置支持类似于数组字典 (dict) 或列表的列表的字典,或者其他嵌套结构的 对象 - 在 JAX 中,这些被称为 PyTrees。本节将解释如何使用它们,提供有用的代码示例, 并指出常见的“注意事项”和模式。
什么是 PyTree?#
PyTree 是一个类似于容器的结构,由类似于容器的 Python 对象构成 - “叶子” PyTrees 或更多 PyTrees。PyTree 可以包含列表、元组和字典。叶子是任何不是 PyTree 的东西,例如数组,但单个叶子也是一个 PyTree。
在机器学习 (ML) 的上下文中,PyTree 可以包含
模型参数
数据集条目
强化学习代理观察
在处理数据集时,您经常会遇到 PyTrees(例如列表的列表的字典)。
下面是一个简单 PyTree 的示例。在 JAX 中,可以使用 jax.tree.leaves()
从树中提取扁平化的叶子,如 此处所示
import jax
import jax.numpy as jnp
example_trees = [
[1, 'a', object()],
(1, (2, 3), ()),
[1, {'k1': 2, 'k2': (3, 4)}, 5],
{'a': 2, 'b': (2, 3)},
jnp.array([1, 2, 3]),
]
# Print how many leaves the pytrees have.
for pytree in example_trees:
# This `jax.tree.leaves()` method extracts the flattened leaves from the pytrees.
leaves = jax.tree.leaves(pytree)
print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
[1, 'a', <object object at 0x7f045679fb70>] has 3 leaves: [1, 'a', <object object at 0x7f045679fb70>]
(1, (2, 3), ()) has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5] has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)} has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32) has 1 leaves: [Array([1, 2, 3], dtype=int32)]
任何由类似于容器的 Python 对象构成的树状结构都可以在 JAX 中被视为 PyTree。如果类 在 PyTree 注册表中,则被视为类似于容器的,默认情况下包括列表、元组和字典。任何类 型不在 PyTree 容器注册表中的对象将被视为树中的叶子节点。
可以通过注册类和指定如何展平树的函数来扩展 pytree 注册表以包含用户定义的容器类;请参见下面的 自定义 pytree 节点。
常见 pytree 函数#
JAX 提供了许多在 pytree 上操作的实用程序。这些可以在 jax.tree_util
子包中找到;为了方便起见,其中许多在 jax.tree
模块中具有别名。
常见函数:jax.tree.map
#
最常用的 pytree 函数是 jax.tree.map()
。它的工作原理类似于 Python 的原生 map
,但它透明地操作整个 pytree。
以下是一个例子
list_of_lists = [
[1, 2, 3],
[1, 2],
[1, 2, 3, 4]
]
jax.tree.map(lambda x: x*2, list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
jax.tree.map()
还允许在多个参数上映射 N 元 函数。例如
another_list_of_lists = list_of_lists
jax.tree.map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
当使用 jax.tree.map()
的多个参数时,输入的结构必须完全匹配。也就是说,列表必须具有相同数量的元素,字典必须具有相同的键等。
使用 ML 模型参数的 jax.tree.map
示例#
此示例演示了在训练简单的 多层感知器 (MLP) 时,pytree 操作如何有用。
首先定义初始模型参数
import numpy as np
def init_mlp_params(layer_widths):
params = []
for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
params.append(
dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
biases=np.ones(shape=(n_out,))
)
)
return params
params = init_mlp_params([1, 128, 128, 1])
使用 jax.tree.map()
检查初始参数的形状
jax.tree.map(lambda x: x.shape, params)
[{'biases': (128,), 'weights': (1, 128)},
{'biases': (128,), 'weights': (128, 128)},
{'biases': (1,), 'weights': (128, 1)}]
接下来,定义用于训练 MLP 模型的函数
# Define the forward pass.
def forward(params, x):
*hidden, last = params
for layer in hidden:
x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
return x @ last['weights'] + last['biases']
# Define the loss function.
def loss_fn(params, x, y):
return jnp.mean((forward(params, x) - y) ** 2)
# Set the learning rate.
LEARNING_RATE = 0.0001
# Using the stochastic gradient descent, define the parameter update function.
# Apply `@jax.jit` for JIT compilation (speed).
@jax.jit
def update(params, x, y):
# Calculate the gradients with `jax.grad`.
grads = jax.grad(loss_fn)(params, x, y)
# Note that `grads` is a pytree with the same structure as `params`.
# `jax.grad` is one of many JAX functions that has
# built-in support for pytrees.
# This is useful - you can apply the SGD update using JAX pytree utilities.
return jax.tree.map(
lambda p, g: p - LEARNING_RATE * g, params, grads
)
自定义 pytree 节点#
本节说明如何在 JAX 中使用 jax.tree_util.register_pytree_node()
和 jax.tree.map()
来扩展将被视为 pytree 中的内部节点(pytree 节点)的 Python 类型集。
为什么需要这样做?在前面的示例中,pytree 被显示为列表、元组和字典,其他所有内容都被视为 pytree 叶子。这是因为,如果您定义了自己的容器类,除非您将其注册到 JAX,否则它将被视为 pytree 叶子。即使您的容器类中包含树,情况也是如此。例如
class Special(object):
def __init__(self, x, y):
self.x = x
self.y = y
jax.tree.leaves([
Special(0, 1),
Special(2, 4),
])
[<__main__.Special at 0x7f046c0ac430>, <__main__.Special at 0x7f046c0ae080>]
因此,如果您尝试使用 jax.tree.map()
期待叶子是容器内的元素,您将收到错误
jax.tree.map(lambda x: x + 1,
[
Special(0, 1),
Special(2, 4)
])
TypeError: unsupported operand type(s) for +: 'Special' and 'int'
作为解决方案,JAX 允许通过类型的全局注册表来扩展被视为内部 pytree 节点的类型集。此外,已注册类型的值会递归遍历。
首先,使用 jax.tree_util.register_pytree_node()
注册新类型
from jax.tree_util import register_pytree_node
class RegisteredSpecial(Special):
def __repr__(self):
return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)
def special_flatten(v):
"""Specifies a flattening recipe.
Params:
v: The value of the registered type to flatten.
Returns:
A pair of an iterable with the children to be flattened recursively,
and some opaque auxiliary data to pass back to the unflattening recipe.
The auxiliary data is stored in the treedef for use during unflattening.
The auxiliary data could be used, for example, for dictionary keys.
"""
children = (v.x, v.y)
aux_data = None
return (children, aux_data)
def special_unflatten(aux_data, children):
"""Specifies an unflattening recipe.
Params:
aux_data: The opaque data that was specified during flattening of the
current tree definition.
children: The unflattened children
Returns:
A reconstructed object of the registered type, using the specified
children and auxiliary data.
"""
return RegisteredSpecial(*children)
# Global registration
register_pytree_node(
RegisteredSpecial,
special_flatten, # Instruct JAX what are the children nodes.
special_unflatten # Instruct JAX how to pack back into a `RegisteredSpecial`.
)
现在您可以遍历特殊容器结构
jax.tree.map(lambda x: x + 1,
[
RegisteredSpecial(0, 1),
RegisteredSpecial(2, 4),
])
[RegisteredSpecial(x=1, y=2), RegisteredSpecial(x=3, y=5)]
现代 Python 配备了有助于更轻松地定义容器的工具。有些工具可以与 JAX 开箱即用,而其他工具则需要更多关注。
例如,Python NamedTuple
子类不需要注册即可被视为 pytree 节点类型
from typing import NamedTuple, Any
class MyOtherContainer(NamedTuple):
name: str
a: Any
b: Any
c: Any
# NamedTuple subclasses are handled as pytree nodes, so
# this will work out-of-the-box.
jax.tree.leaves([
MyOtherContainer('Alice', 1, 2, 3),
MyOtherContainer('Bob', 4, 5, 6)
])
['Alice', 1, 2, 3, 'Bob', 4, 5, 6]
请注意,name
字段现在显示为叶子,因为所有元组元素都是子元素。这就是您不必以硬编码方式注册类时发生的情况。
Pytree 和 JAX 变换#
许多 JAX 函数,如 jax.lax.scan()
,在数组的 pytree 上操作。此外,所有 JAX 函数变换都可以应用于接受数组 pytree 作为输入并产生数组 pytree 作为输出的函数。
某些 JAX 函数变换接受指定如何处理某些输入或输出值的可选参数(例如 in_axes
和 out_axes
参数到 jax.vmap()
)。这些参数也可以是 pytree,并且它们的结构必须与相应参数的 pytree 结构相对应。特别是,为了能够将这些参数 pytree 中的叶子“匹配”到参数 pytree 中的值,参数 pytree 通常被限制为参数 pytree 的树前缀。
例如,如果您将以下输入传递给 jax.vmap()
(请注意,函数的输入参数被认为是一个元组)
vmap(f, in_axes=(a1, {"k1": a2, "k2": a3}))
那么您可以使用以下 in_axes
pytree 来指定只有 k2
参数被映射(axis=0
),而其余参数没有被映射(axis=None
)
vmap(f, in_axes=(None, {"k1": None, "k2": 0}))
可选参数 pytree 结构必须与主输入 pytree 结构匹配。但是,可选参数可以选择性地指定为“前缀”pytree,这意味着单个叶子值可以应用于整个子 pytree。
例如,如果您具有与上述相同的 jax.vmap()
输入,但希望只映射字典参数,您可以使用
vmap(f, in_axes=(None, 0)) # equivalent to (None, {"k1": 0, "k2": 0})
或者,如果您希望每个参数都被映射,您可以编写一个应用于整个参数元组 pytree 的单个叶子值
vmap(f, in_axes=0) # equivalent to (0, {"k1": 0, "k2": 0})
这恰好是 jax.vmap()
的默认 in_axes
值。
相同的逻辑适用于其他可选参数,这些参数引用变换函数的特定输入或输出值,例如 jax.vmap()
中的 out_axes
。
显式键路径#
在 pytree 中,每个叶子都有一个键路径。叶子的键路径是一个 list
,包含键,其中列表的长度等于叶子在 pytree 中的深度。每个键都是一个 可哈希对象,它表示相应 pytree 节点类型的索引。键的类型取决于 pytree 节点类型;例如,dict
的键类型与 tuple
的键类型不同。
对于内置 pytree 节点类型,任何 pytree 节点实例的键集都是唯一的。对于包含此属性的节点的 pytree,每个叶子的键路径都是唯一的。
JAX 具有以下 jax.tree_util.*
方法,用于处理键路径
jax.tree_util.tree_flatten_with_path()
:工作原理类似于jax.tree.flatten()
,但会返回键路径。jax.tree_util.tree_map_with_path()
:工作原理类似于jax.tree.map()
,但该函数还会将键路径作为参数。jax.tree_util.keystr()
:给定一个通用键路径,返回一个易于阅读的字符串表达式。
例如,一个用例是打印与特定叶子值相关的调试信息
import collections
ATuple = collections.namedtuple("ATuple", ('name'))
tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)
for key_path, value in flattened:
print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')
Value of tree[0]: 1
Value of tree[1]['k1']: 2
Value of tree[1]['k2'][0]: 3
Value of tree[1]['k2'][1]: 4
Value of tree[2].name: foo
为了表达键路径,JAX 为内置 pytree 节点类型提供了一些默认键类型,即
SequenceKey(idx: int)
:用于列表和元组。DictKey(key: Hashable)
:用于字典。GetAttrKey(name: str)
:用于namedtuple
以及最好是自定义 pytree 节点(下一节中将详细介绍)
您可以自由地为自定义节点定义自己的键类型。它们将与 jax.tree_util.keystr()
一起使用,只要它们的 __str__()
方法也使用易于阅读的表达式覆盖。
for key_path, _ in flattened:
print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')
Key path of tree[0]: (SequenceKey(idx=0),)
Key path of tree[1]['k1']: (SequenceKey(idx=1), DictKey(key='k1'))
Key path of tree[1]['k2'][0]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=0))
Key path of tree[1]['k2'][1]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=1))
Key path of tree[2].name: (SequenceKey(idx=2), GetAttrKey(name='name'))
常见的 pytree 陷阱#
本节介绍使用 JAX pytree 时遇到的最常见问题(“陷阱”)。
将 pytree 节点误认为叶子#
需要注意的一个常见陷阱是意外地引入树节点而不是叶子
a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]
# Try to make another pytree with ones instead of zeros.
shapes = jax.tree.map(lambda x: x.shape, a_tree)
jax.tree.map(jnp.ones, shapes)
[(Array([1., 1.], dtype=float32), Array([1., 1., 1.], dtype=float32)),
(Array([1., 1., 1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))]
这里发生的事情是,数组的 shape
是一个元组,它是一个 pytree 节点,其元素作为叶子。因此,在映射中,它不是在例如 (2, 3)
上调用 jnp.ones
,而是在 2
和 3
上调用。
解决方案将取决于具体情况,但有两个广泛适用的选项
重写代码以避免中间
jax.tree.map()
。将元组转换为 NumPy 数组 (
np.array
) 或 JAX NumPy 数组 (jnp.array
),这使得整个序列成为叶子。
jax.tree_util
对 None
的处理 #
jax.tree_util
函数将 None
视为 pytree 节点的缺失,而不是叶子。
jax.tree.leaves([None, None, None])
[]
要将 None
视为叶子,可以使用 is_leaf
参数。
jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None)
[None, None, None]
自定义 pytree 和带有意外值的初始化 #
另一个与用户定义的 pytree 对象相关的常见问题是,JAX 变换偶尔会使用意外的值初始化它们,因此在初始化时执行的任何输入验证都可能失败。例如
class MyTree:
def __init__(self, a):
self.a = jnp.asarray(a)
register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
lambda _, args: MyTree(*args))
tree = MyTree(jnp.arange(5.0))
jax.vmap(lambda x: x)(tree) # Error because object() is passed to `MyTree`.
TypeError: Value '<object object at 0x7f0435724490>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
jax.jacobian(lambda x: x)(tree) # Error because MyTree(...) is passed to `MyTree`.
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:4831: FutureWarning: None encountered in jnp.array(); this is currently treated as NaN. In the future this will result in an error.
return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
TypeError: Value '<object object at 0x7f04357248d0>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
在第一个使用
jax.vmap(...)(tree)
的情况下,JAX 的内部使用object()
值的数组来推断树的结构。在第二个使用
jax.jacobian(...)(tree)
的情况下,将树映射到树的函数的雅可比矩阵定义为树的树。
潜在解决方案 1
自定义 pytree 类的
__init__
和__new__
方法通常应避免执行任何数组转换或其他输入验证,否则应预期并处理这些特殊情况。例如
class MyTree:
def __init__(self, a):
if not (type(a) is object or a is None or isinstance(a, MyTree)):
a = jnp.asarray(a)
self.a = a
潜在解决方案 2
构建您的自定义
tree_unflatten
函数,使其避免调用__init__
。如果您选择此方法,请确保您的tree_unflatten
函数与__init__
保持同步,如果代码在更新时需要。
def tree_unflatten(aux_data, children):
del aux_data # Unused in this class.
obj = object.__new__(MyTree)
obj.a = a
return obj
常见的 pytree 模式 #
本节介绍了一些 JAX pytree 的最常见模式。
使用 jax.tree.map
和 jax.tree.transpose
转置 pytree #
为了转置一个 pytree(将树的列表转换为列表的树),JAX 有两个函数:{func} jax.tree.map
(更基础)和 jax.tree.transpose()
(更灵活、更复杂和更冗长)。
选项 1:使用 jax.tree.map()
。下面是一个例子
def tree_transpose(list_of_trees):
"""
Converts a list of trees of identical structure into a single tree of lists.
"""
return jax.tree.map(lambda *xs: list(xs), *list_of_trees)
# Convert a dataset from row-major to column-major.
episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
tree_transpose(episode_steps)
{'obs': [3, 4], 't': [1, 2]}
选项 2:对于更复杂的转置,使用 jax.tree.transpose()
,它更冗长,但允许您指定内部和外部 pytree 的结构,以获得更大的灵活性。例如
jax.tree.transpose(
outer_treedef = jax.tree.structure([0 for e in episode_steps]),
inner_treedef = jax.tree.structure(episode_steps[0]),
pytree_to_transpose = episode_steps
)
{'obs': [3, 4], 't': [1, 2]}