jax.tree_util.build_tree

目录

jax.tree_util.build_tree#

jax.tree_util.build_tree(treedef, xs)[source]#

从嵌套可迭代结构构建 treedef

参数:
  • treedef (PyTreeDef) – 要构建的 PyTreeDef 结构。

  • xs (Any) – 与 treedef 的元数匹配的嵌套可迭代。

返回值:

具有由 treedef 定义的结构的对象

返回类型:

Any

另请参阅

示例

>>> import jax
>>> tree = [(1, 2), {'a': 3, 'b': 4}]
>>> treedef = jax.tree.structure(tree)

Both build_treejax.tree_util.tree_unflatten() 可以从新值重建树,但 build_tree 以嵌套结构而不是扁平结构的形式接受这些值。

>>> jax.tree_util.build_tree(treedef, [[10, 11], [12, 13]])
[(10, 11), {'a': 12, 'b': 13}]
>>> jax.tree_util.tree_unflatten(treedef, [10, 11, 12, 13])
[(10, 11), {'a': 12, 'b': 13}]