jax.tree_util.treedef_children#
- jax.tree_util.treedef_children(treedef)[源代码]#
返回直接子节点的 treedef 列表
- 参数:
treedef (PyTreeDef) – 单个 PyTreeDef
- 返回:
一个 PyTreeDef 列表,表示 treedef 的子节点。
- 返回类型:
list[PyTreeDef]
示例
>>> import jax >>> x = [(1, 2), 3, {'a': 4}] >>> treedef = jax.tree.structure(x) >>> jax.tree_util.treedef_children(treedef) [PyTreeDef((*, *)), PyTreeDef(*), PyTreeDef({'a': *})] >>> _ == [jax.tree.structure(vals) for vals in x] True