jax.tree.leaves#
- jax.tree.leaves(tree, is_leaf=None)[源代码]#
获取 pytree 的叶子节点。
- 参数:
tree (Any) – 要获取叶节点的 pytree。
is_leaf (Callable[[Any], bool] | None | None) – 一个可选的函数,将在每个展平步骤调用。它应返回一个布尔值,指示是否应遍历当前对象,或者是否应立即停止,并将整个子树视为叶节点。
- 返回:
树叶节点的列表。
- 返回类型:
leaves
示例
>>> import jax >>> jax.tree.leaves([1, (2, 3), [4, 5]]) [1, 2, 3, 4, 5]