jax.tree.leaves_with_path#
- jax.tree.leaves_with_path(tree, is_leaf=None)[源代码]#
获取 pytree 的叶子,类似于
tree_leaves
,并返回每个叶子的键路径。- 参数:
tree (Any) – 一个 pytree。如果它包含自定义类型,建议使用
register_pytree_with_keys
注册。is_leaf (Callable[[Any], bool] | None | None)
- 返回:
一个键-叶对的列表,每个键-叶对包含一个叶子及其键路径。
- 返回类型:
示例
>>> import jax >>> jax.tree.leaves_with_path([1, {'x': 3}]) [((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]