jax.tree_util.treedef_children


jax.tree_util.treedef_children(treedef)

Return a list of treedefs for the direct children of a treedef.

Parameters:

**treedef** (PyTreeDef) – a single PyTreeDef

Returns:

List of PyTreeDefs representing the children of `treedef`.

Return type:

list[PyTreeDef]
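The sketch below illustrates a few consequences of the description above. It assumes a recent JAX release in which the `jax.tree` module is available. A leaf has no children. Dict children follow the sorted-key order that JAX also uses when flattening. For an internal node, the children's `num_leaves` add up to the parent's.

```python
import jax

# A single leaf has no children, so the result is an empty list.
leaf_def = jax.tree.structure(1.0)
leaf_children = jax.tree_util.treedef_children(leaf_def)

# Dict children come out in sorted-key order ('a' before 'b'),
# matching the order in which jax.tree.leaves visits the values.
dict_def = jax.tree.structure({'b': 1, 'a': (2, 3)})
dict_children = jax.tree_util.treedef_children(dict_def)

# For an internal node, the children's leaf counts sum to the parent's.
total = sum(c.num_leaves for c in dict_children)

print(leaf_children)                 # []
print(dict_children)                 # [PyTreeDef((*, *)), PyTreeDef(*)]
print(total == dict_def.num_leaves)  # True
```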

Examples

>>> import jax
>>> x = [(1, 2), 3, {'a': 4}]
>>> treedef = jax.tree.structure(x)
>>> jax.tree_util.treedef_children(treedef)
[PyTreeDef((*, *)), PyTreeDef(*), PyTreeDef({'a': *})]
>>> _ == [jax.tree.structure(vals) for vals in x]
True
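Each child is itself a PyTreeDef, so `treedef_children` can drive a recursive walk over a tree's structure without touching any leaf values. The helper `tree_depth` below is hypothetical and not part of `jax.tree_util`. It is a minimal sketch that treats any node without children, whether a leaf or an empty container, as depth 0.

```python
import jax


def tree_depth(treedef) -> int:
    """Hypothetical helper: number of container levels above the deepest leaf.

    Any node whose treedef has no children (a leaf or an empty container)
    is assigned depth 0.
    """
    children = jax.tree_util.treedef_children(treedef)
    if not children:
        return 0
    # One level for this node, plus the deepest child subtree.
    return 1 + max(tree_depth(c) for c in children)


x = [(1, 2), 3, {'a': 4}]
print(tree_depth(jax.tree.structure(x)))  # 2
print(tree_depth(jax.tree.structure(5)))  # 0
```

Recursing on structure alone keeps the walk independent of leaf types, so the same helper works for trees of arrays, Python scalars, or placeholder values.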