jax.tree.structure#
- jax.tree.structure(tree, is_leaf=None)[源代码]#
获取 pytree 的 treedef。
- 参数:
tree (Any) – 要获取叶子的 pytree
is_leaf (None | Callable[[Any], bool] | None) – 可选指定的函数,将在每个扁平化步骤中调用。它应该返回一个布尔值,该值指示是否应遍历当前对象进行扁平化,或者是否应立即停止,并将整个子树视为叶子。
- 返回:
一个表示树结构的 PyTreeDef。
- 返回类型:
pytreedef
示例
>>> import jax >>> jax.tree.structure([1, (2, 3), [4, 5]]) PyTreeDef([*, (*, *), [*, *]])