jax.tree_util.tree_unflatten

jax.tree_util.tree_unflatten(treedef, leaves)

Alias of jax.tree.unflatten().

Parameters:
treedef (PyTreeDef)
leaves (Iterable[Leaf])

Returns:
Any
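The typical use pairs tree_unflatten with tree_flatten. You flatten a pytree, transform its leaves, and then rebuild a pytree with the original structure. The sketch below is a minimal illustration. It assumes a recent JAX release in which the jax.tree namespace is available, and the example tree is hypothetical.

```python
import jax
from jax import tree_util

# Hypothetical nested pytree: dicts, tuples and lists are registered containers.
tree = {"a": 1, "b": (2, 3), "c": [4]}

# tree_flatten returns the list of leaves and a PyTreeDef describing the structure.
leaves, treedef = tree_util.tree_flatten(tree)

# Transform the leaves (here: double each one); the structure is left untouched.
new_leaves = [x * 2 for x in leaves]

# Rebuild a pytree with the original structure from the new leaves.
rebuilt = tree_util.tree_unflatten(treedef, new_leaves)

# jax.tree.unflatten is the same operation under its newer name.
rebuilt_alias = jax.tree.unflatten(treedef, new_leaves)
print(rebuilt)
```

The leaves passed in should match the structure recorded in treedef, one leaf per position (see treedef.num_leaves). Only the leaf values change.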