jax.tree_util.tree_unflatten

jax.tree_util.tree_unflatten(treedef, leaves)

Alias of jax.tree.unflatten().

Parameters:
treedef (PyTreeDef)
leaves (Iterable[Leaf])

Return type:
Any
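The following is a minimal sketch of how this function is typically used: a round trip with jax.tree_util.tree_flatten, which returns the flat list of leaves together with the PyTreeDef that tree_unflatten consumes. The example tree and the leaf-doubling step are illustrative choices, not part of the API. It assumes the leaves passed in follow the order produced by tree_flatten and that their count matches the structure (treedef.num_leaves).

```python
import jax

# An illustrative nested pytree: a dict holding a tuple and a list.
tree = {"a": (1, 2), "b": [3, {"c": 4}]}

# tree_flatten returns the leaves (dict keys are visited in sorted order)
# and a PyTreeDef describing the container structure.
leaves, treedef = jax.tree_util.tree_flatten(tree)
print(leaves)

# The number of leaves supplied must match the structure.
assert len(leaves) == treedef.num_leaves

# Rebuild the same structure around new leaf values (each leaf doubled here).
doubled = jax.tree_util.tree_unflatten(treedef, [x * 2 for x in leaves])
print(doubled)

# Unflattening the original leaves reproduces the original tree.
roundtrip = jax.tree_util.tree_unflatten(treedef, leaves)
print(roundtrip == tree)
```

Because treedef records only the structure, the same treedef can be reused to rebuild the tree around any sequence of leaves of the right length, which is the usual pattern when transforming leaves in bulk.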