jax.tree_util 模块

jax.tree_util 模块#

用于处理树状容器数据结构的实用程序。

此模块提供了一组用于处理树状数据结构的实用程序函数,例如嵌套元组、列表和字典。我们将这些结构称为 pytree。它们是树,因为它们是递归定义的(任何非 pytree 都是 pytree,即叶子,任何 pytree 的 pytree 都是 pytree)并且可以递归地进行操作(对象标识等价性不会通过映射操作保留,并且结构不能包含引用循环)。

被认为是 pytree 节点(例如,可以映射,而不是被视为叶子)的 Python 类型集是可扩展的。有一个模块级类型注册表,并且类层次结构被忽略。通过注册一个新的 pytree 节点类型,该类型实际上对本文件中的实用程序函数变得透明。

此模块的主要目的是启用用户定义数据结构与 JAX 变换(例如,jit)之间的互操作性。这不是一个通用的树状数据结构处理库。

查看 JAX pytree 说明 获取示例。

函数列表#

Partial(func, *args, **kw)

一个在 pytree 中工作的 functools.partial 版本。

all_leaves(iterable[, is_leaf])

测试给定可迭代对象中的所有元素是否都是叶子。

build_tree(treedef, xs)

从嵌套的可迭代结构构建 treedef

register_dataclass(nodetype, data_fields, ...)

扩展被认为是 pytree 中内部节点的类型集。

register_pytree_node(nodetype, flatten_func, ...)

扩展被认为是 pytree 中内部节点的类型集。

register_pytree_node_class(cls)

扩展被认为是 pytree 中内部节点的类型集。

register_pytree_with_keys(nodetype, ...[, ...])

扩展被认为是 pytree 中内部节点的类型集。

register_pytree_with_keys_class(cls)

扩展被认为是 pytree 中内部节点的类型集。

register_static(cls)

cls 注册为没有叶子的 pytree。

tree_flatten_with_path(tree[, is_leaf])

tree_flatten 一样展平 pytree,但也会返回每个叶子的键路径。

tree_leaves_with_path(tree[, is_leaf])

tree_leaves 一样获取 pytree 的叶子,并返回每个叶子的键路径。

tree_map_with_path(f, tree, *rest[, is_leaf])

将多输入函数映射到 pytree 键路径和参数上,以生成新的 pytree。

treedef_children(treedef)

返回直接子节点的 treedef 列表。

treedef_is_leaf(treedef)

如果 treedef 代表叶子,则返回 True。

treedef_tuple(treedefs)

从子 treedef 的可迭代对象中生成元组 treedef。

KeyEntry

类型变量。

KeyPath

别名 tuple[KeyEntry, ...]

keystr(keys)

帮助美化打印键的元组。

传统 API#

这些 API 现在可以通过 jax.tree 访问。

tree_all(tree, *[, is_leaf])

别名 jax.tree.all().

tree_flatten(tree[, is_leaf])

别名 jax.tree.flatten().

tree_leaves(tree[, is_leaf])

别名 jax.tree.leaves().

tree_map(f, tree, *rest[, is_leaf])

别名 jax.tree.map().

tree_reduce(function, tree[, initializer, ...])

别名 jax.tree.reduce().

tree_structure(tree[, is_leaf])

别名 jax.tree.structure().

tree_transpose(outer_treedef, inner_treedef, ...)

别名 jax.tree.transpose().

tree_unflatten(treedef, leaves)

别名 jax.tree.unflatten().