jax.tree_util.treedef_children#

jax.tree_util.treedef_children(treedef)[源代码]#

返回直接子节点的 treedef 列表

参数:

treedef (PyTreeDef) – 单个 PyTreeDef

返回:

一个 PyTreeDef 列表,表示 treedef 的子节点。

返回类型:

list[PyTreeDef]

示例

>>> import jax
>>> x = [(1, 2), 3, {'a': 4}]
>>> treedef = jax.tree.structure(x)
>>> jax.tree_util.treedef_children(treedef)
[PyTreeDef((*, *)), PyTreeDef(*), PyTreeDef({'a': *})]
>>> _ == [jax.tree.structure(vals) for vals in x]
True