jax.tree_util.all_leaves#

jax.tree_util.all_leaves(iterable, is_leaf=None)[源代码]#

测试给定可迭代对象中的所有元素是否都是叶子节点。

此函数在高级用例中很有用,例如,如果一个库允许对扁平的叶子节点可迭代对象进行任意的映射操作,它可能需要检查结果是否仍然是扁平的叶子节点可迭代对象。

参数:
  • iterable (Iterable[Any]) – 叶子节点的可迭代对象。

  • is_leaf (Callable[[Any], bool] | None | None)

返回:

一个布尔值,指示输入中的所有元素是否都是叶子节点。

返回类型:

bool

示例

>>> import jax
>>> tree = {"a": [1, 2, 3]}
>>> assert all_leaves(jax.tree_util.tree_leaves(tree))
>>> assert not all_leaves([tree])