jax.tree_util
模块
用于处理类似树的容器数据结构的实用程序。
此模块提供一小组用于处理类似树的数据结构的实用函数,例如嵌套的元组、列表和字典。我们将这些结构称为 pytrees。它们是树,因为它们是递归定义的(任何非 pytree 都是 pytree,即叶子,并且任何 pytree 的 pytree 都是 pytree)并且可以递归地操作(映射操作不保留对象标识等价性,并且这些结构不能包含引用循环)。
被认为是 pytree 节点的 Python 类型集合(例如,可以被映射,而不是被视为叶子)是可扩展的。存在一个模块级的类型注册表,并且忽略类层次结构。通过注册一个新的 pytree 节点类型,该类型实际上对该文件中的实用函数变为透明。
此模块的主要目的是实现用户定义的数据结构和 JAX 转换(例如 jit)之间的互操作性。这并非旨在成为一个通用的树状数据结构处理库。
有关示例,请参阅JAX pytrees 说明。
函数列表
Partial (func, *args, **kw)
|
一个可在 pytree 中使用的 functools.partial 版本。 |
all_leaves (iterable[, is_leaf])
|
测试给定可迭代对象中的所有元素是否都是叶子。 |
build_tree (treedef, xs)
|
从嵌套的可迭代结构构建 treedef |
register_dataclass (nodetype[, data_fields, ...])
|
扩展在 pytree 中被认为是内部节点的类型集合。 |
register_pytree_node (nodetype, flatten_func, ...)
|
扩展在 pytree 中被认为是内部节点的类型集合。 |
register_pytree_node_class (cls)
|
扩展在 pytree 中被认为是内部节点的类型集合。 |
register_pytree_with_keys (nodetype, ...[, ...])
|
扩展在 pytree 中被认为是内部节点的类型集合。 |
register_pytree_with_keys_class (cls)
|
扩展在 pytree 中被认为是内部节点的类型集合。 |
register_static (cls)
|
将 cls 注册为没有叶子的 pytree。 |
tree_flatten_with_path (tree[, is_leaf])
|
jax.tree.flatten_with_path() 的别名。
|
tree_leaves_with_path (tree[, is_leaf])
|
jax.tree.leaves_with_path() 的别名。
|
tree_map_with_path (f, tree, *rest[, is_leaf])
|
jax.tree.map_with_path() 的别名。
|
treedef_children (treedef)
|
返回直接子级的 treedef 列表 |
treedef_is_leaf (treedef)
|
如果 treedef 表示一个叶子,则返回 True。 |
treedef_tuple (treedefs)
|
从子 treedef 的可迭代对象创建一个元组 treedef。 |
KeyEntry
|
类型变量。 |
KeyPath
|
tuple [KeyEntry , ...] 的别名
|
keystr (keys)
|
辅助函数,用于漂亮地打印键的元组。 |
旧版 API
这些 API 现在通过 jax.tree
访问。
tree_all (tree, *[, is_leaf])
|
jax.tree.all() 的别名。
|
tree_flatten (tree[, is_leaf])
|
jax.tree.flatten() 的别名。
|
tree_leaves (tree[, is_leaf])
|
jax.tree.leaves() 的别名。
|
tree_map (f, tree, *rest[, is_leaf])
|
jax.tree.map() 的别名。
|
tree_reduce (function, tree[, initializer, ...])
|
jax.tree.reduce() 的别名。
|
tree_structure (tree[, is_leaf])
|
jax.tree.structure() 的别名。
|
tree_transpose (outer_treedef, inner_treedef, ...)
|
jax.tree.transpose() 的别名。
|
tree_unflatten (treedef, leaves)
|
jax.tree.unflatten() 的别名。
|