jax.tree_util.treedef_is_leaf#
- jax.tree_util.treedef_is_leaf(treedef)[源代码]#
如果 treedef 代表一个叶节点,则返回 True。
- 参数:
treedef (PyTreeDef) – 要检查的树结构。
- 返回值:
如果 treedef 是一个叶节点(即只有一个节点),则返回 True;否则返回 False。
- 返回类型:
示例
>>> import jax >>> tree1 = jax.tree.structure(1) >>> jax.tree_util.treedef_is_leaf(tree1) True >>> tree2 = jax.tree.structure([1, 2]) >>> jax.tree_util.treedef_is_leaf(tree2) False