jax.tree_util.tree_leaves_with_path# jax.tree_util.tree_leaves_with_path(tree, is_leaf=None)[source]# 获取 pytree 的叶子,类似于 tree_leaves,并返回每个叶子的键路径。 参数: tree (Any) – 一个 pytree。如果它包含自定义类型,则必须使用 register_pytree_with_keys 注册。 is_leaf (Callable[[Any], bool] | None | None) 返回值: 键-叶子对列表,每个对包含一个叶子及其键路径。 返回类型: list[tuple[KeyPath, Any]] 另请参阅 jax.tree_util.tree_leaves() jax.tree_util.tree_flatten_with_path()