jax.tree.flatten

内容

jax.tree.flatten#

jax.tree.flatten(tree, is_leaf=None)[source]#

将 pytree 展平。

展平顺序(即输出列表中元素的顺序)是确定的,对应于从左到右的深度优先树遍历。

参数:
  • **tree** (Any) – 要展平的 pytree。

  • **is_leaf** (Callable[[Any], bool] | None | None) – 可选指定的函数,将在每个展平步骤中被调用。它应该返回一个布尔值,如果为真则停止遍历并将整个子树视为叶子,如果为假则表示展平应该遍历当前对象。

返回:

一对,其中第一个元素是叶子值的列表,第二个元素是表示展平树结构的 treedef。

返回类型:

tuple[list[tree_util.Leaf], tree_util.PyTreeDef]

示例

>>> import jax
>>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]])
>>> vals
[1, 2, 3, 4, 5]
>>> treedef
PyTreeDef([*, (*, *), [*, *]])