jax.tree_util.treedef_children#
- jax.tree_util.treedef_children(treedef)[source]#
返回直接子节点的 treedef 列表
- 参数:
**treedef** (PyTreeDef) – 单个 PyTreeDef
- 返回值:
表示 treedef 子节点的 PyTreeDef 列表。
- 返回类型:
list[PyTreeDef]
示例
>>> import jax >>> x = [(1, 2), 3, {'a': 4}] >>> treedef = jax.tree.structure(x) >>> jax.tree_util.treedef_children(treedef) [PyTreeDef((*, *)), PyTreeDef(*), PyTreeDef({'a': *})] >>> _ == [jax.tree.structure(vals) for vals in x] True