jax.tree_util.tree_structure
- jax.tree_util.tree_structure(tree, is_leaf=None)
Alias of jax.tree.structure().
- Parameters:
tree (Any)
is_leaf (None | Callable[[Any], bool])
- Return type:
PyTreeDef
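A minimal usage sketch, assuming a JAX version that provides both `jax.tree_util.tree_structure` and its `jax.tree.structure` alias. The example trees are made up for illustration. It shows that the returned `PyTreeDef` records only the container layout, not the leaf values, and that `is_leaf` can stop the traversal early so a subtree is treated as a single leaf.

```python
import jax

# An illustrative pytree: a dict containing a scalar and a tuple.
tree = {"a": 1, "b": (2, 3)}

# The structure captures the containers; each leaf becomes a placeholder.
treedef = jax.tree_util.tree_structure(tree)
print(treedef.num_leaves)  # 3

# jax.tree.structure is an alias, so it yields an equal PyTreeDef.
assert treedef == jax.tree.structure(tree)

# Leaf values do not affect the structure, only the layout does.
same_shape = jax.tree_util.tree_structure({"a": 10, "b": (20, 30)})
assert same_shape == treedef

# The PyTreeDef can rebuild a tree of the same layout from new leaves.
rebuilt = treedef.unflatten([4, 5, 6])
print(rebuilt)  # {'a': 4, 'b': (5, 6)}

# With is_leaf, any tuple is treated as a leaf instead of being traversed.
tuple_as_leaf = jax.tree_util.tree_structure(
    tree, is_leaf=lambda x: isinstance(x, tuple)
)
print(tuple_as_leaf.num_leaves)  # 2
```

Comparing `PyTreeDef` objects like this is a common way to check that two pytrees are compatible before combining their leaves.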