jax.tree_util.tree_flatten_with_path

jax.tree_util.tree_flatten_with_path#

jax.tree_util.tree_flatten_with_path(tree, is_leaf=None)[source]#

tree_flatten 一样展平 pytree,但也会返回每个叶子的键路径。

参数:
  • tree (Any) – 要展平的 pytree。如果它包含自定义类型,则必须使用 register_pytree_with_keys 注册。

  • is_leaf (Callable[[Any], bool] | None | None)

返回值:

一个元组,其中第一个元素是一个键-叶对列表,每个对都包含一个叶子及其键路径。第二个元素是一个 treedef,表示展平树的结构。

返回类型:

tuple[list[tuple[KeyPath, Any]], PyTreeDef]