jax.tree_util.all_leaves#
- jax.tree_util.all_leaves(iterable, is_leaf=None)[源代码]#
测试给定可迭代对象中的所有元素是否都是叶子节点。
此函数在高级情况下很有用,例如,如果一个库允许对叶子节点的扁平可迭代对象进行任意映射操作,它可能需要检查结果是否仍然是叶子节点的扁平可迭代对象。
- 参数:
iterable (Iterable[Any]) – 叶子节点的可迭代对象。
is_leaf (Callable[[Any], bool] | None | None)
- 返回:
一个布尔值,指示输入中的所有元素是否都是叶子节点。
- 返回类型:
示例
>>> import jax >>> tree = {"a": [1, 2, 3]} >>> assert all_leaves(jax.tree_util.tree_leaves(tree)) >>> assert not all_leaves([tree])