jax.tree_util.register_pytree_with_keys#
- jax.tree_util.register_pytree_with_keys(nodetype, flatten_with_keys, unflatten_func, flatten_func=None)[源代码]#
扩展 pytree 中被视为内部节点的类型集合。
这是
register_pytree_node
的更强大的替代方案,它允许您在展平和树映射时访问每个 pytree 叶子的键路径。- 参数:
nodetype (type[T]) – 一个 Python 类型,将其视为内部 pytree 节点。
flatten_with_keys (Callable[[T], tuple[Iterable[KeyLeafPair], _AuxData]]) – 一个在扁平化期间使用的函数,它接收一个类型为
nodetype
的值,并返回一个元组,其中 (1) 包含每个键路径及其子元素的元组的可迭代对象,以及 (2) 一些可哈希的辅助数据,这些数据将存储在 treedef 中,并传递给unflatten_func
。unflatten_func (Callable[[_AuxData, Iterable[Any]], T]) – 一个接收两个参数的函数:由
flatten_func
返回并存储在 treedef 中的辅助数据,以及未扁平化的子元素。该函数应返回一个nodetype
的实例。flatten_func (None | Callable[[T], tuple[Iterable[Any], _AuxData]] | None) – 一个可选函数,类似于
flatten_with_keys
,但仅返回子元素和辅助数据。它必须以与flatten_with_keys
相同的顺序返回子元素,并返回相同的辅助数据。此参数是可选的,仅在调用像tree_map
和tree_flatten
这样不带键的函数时需要以加快遍历速度。
示例
首先,我们将定义一个自定义类型
>>> class MyContainer: ... def __init__(self, size): ... self.x = jnp.zeros(size) ... self.y = jnp.ones(size) ... self.size = size
现在,使用一个支持键的扁平化函数来注册它
>>> from jax.tree_util import register_pytree_with_keys_class, GetAttrKey >>> def flatten_with_keys(obj): ... children = [(GetAttrKey('x'), obj.x), ... (GetAttrKey('y'), obj.y)] # children must contain arrays & pytrees ... aux_data = (obj.size,) # aux_data must contain static, hashable data. ... return children, aux_data ... >>> def unflatten(aux_data, children): ... # Here we avoid `__init__` because it has extra logic we don't require: ... obj = object.__new__(MyContainer) ... obj.x, obj.y = children ... obj.size, = aux_data ... return obj ... >>> jax.tree_util.register_pytree_node(MyContainer, flatten_with_keys, unflatten)
现在,这可以与像
tree_flatten_with_path()
这样的函数一起使用>>> m = MyContainer(4) >>> leaves, treedef = jax.tree_util.tree_flatten_with_path(m)