jax.tree_util.tree_structure
- jax.tree_util.tree_structure(tree, is_leaf=None)[source]
Alias of jax.tree.structure().
- Parameters:
tree (Any)
is_leaf (Callable[[Any], bool] | None)
- Return type:
PyTreeDef
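The following is a minimal usage sketch. It assumes a recent JAX release that includes the `jax.tree` module. It shows that the returned `PyTreeDef` records only the container layout and ignores the leaf values. It also checks that the function agrees with its alias `jax.tree.structure()`, and it uses `is_leaf` to treat tuples as single leaves. The example tree and the tuple-based `is_leaf` predicate are illustrative choices, not part of the API.

```python
import jax
from jax import tree_util

# A small pytree: a dict holding a tuple and a nested list/dict.
tree = {"a": (1, 2), "b": [3, {"c": 4}]}

# tree_structure keeps only the container layout; leaves become placeholders.
treedef = tree_util.tree_structure(tree)
print(treedef)
print(treedef.num_leaves)  # 4

# tree_util.tree_structure is an alias of jax.tree.structure, so both agree.
assert treedef == jax.tree.structure(tree)

# Two trees with the same layout but different leaf values share a structure.
other = {"a": (10, 20), "b": [30, {"c": 40}]}
assert tree_util.tree_structure(other) == treedef

# is_leaf: stop flattening at any tuple and treat the whole tuple as one leaf.
treedef_tuple_leaf = tree_util.tree_structure(
    tree, is_leaf=lambda x: isinstance(x, tuple)
)
print(treedef_tuple_leaf.num_leaves)  # 3: (1, 2), 3 and 4
```

Because structures compare by layout alone, comparing two `PyTreeDef` values this way is a cheap check that two pytrees can be combined leaf by leaf, for example with `jax.tree.map`.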