jax.tree.flatten_with_path#

jax.tree.flatten_with_path(tree, is_leaf=None)[源代码]#

tree_flatten 一样展平一个 pytree,但也会返回每个叶子的键路径。

参数:
  • tree (Any) – 要展平的 pytree。如果它包含自定义类型,建议使用 register_pytree_with_keys 注册。

  • is_leaf (Callable[[Any], bool] | None | None)

返回:

一个包含两个元素的元组,第一个元素是键-叶对的列表,每个对包含一个叶子及其键路径。第二个元素是表示展平树结构的 treedef。

返回类型:

tuple[list[tuple[tree_util.KeyPath, Any]], tree_util.PyTreeDef]

示例

>>> import jax
>>> path_vals, treedef = jax.tree.flatten_with_path([1, {'x': 3}])
>>> path_vals
[((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]
>>> treedef
PyTreeDef([*, {'x': *}])