jax.tree.unflatten#
- jax.tree.unflatten(treedef, leaves)[source]#
根据 treedef 和叶子重建一个 pytree。
是
tree_flatten()
的逆操作。- 参数:
treedef (tree_util.PyTreeDef) – 要重建的 treedef
leaves (Iterable[tree_util.Leaf]) – 用于重建的叶子迭代器。迭代器必须与 treedef 的叶子匹配。
- 返回值:
重建的 pytree,包含放置在
treedef
描述的结构中的leaves
。- 返回类型:
Any
示例
>>> import jax >>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]]) >>> newvals = [100, 200, 300, 400, 500] >>> jax.tree.unflatten(treedef, newvals) [100, (200, 300), [400, 500]]