jax.tree.all#
- jax.tree.all(tree, *, is_leaf=None)[源代码]#
在树的叶子上调用 all()。
- 参数:
tree (Any) – 要评估的 pytree
is_leaf (Callable[[Any], bool] | None | None) – 一个可选指定的函数,将在每个展平步骤中调用。它应该返回一个布尔值,指示展平是否应该遍历当前对象,或者是否应该立即停止,并将整个子树视为叶子。
- 返回:
布尔值 True 或 False
- 返回类型:
结果
示例
>>> import jax >>> jax.tree.all([True, {'a': True, 'b': (True, True)}]) True >>> jax.tree.all([False, (True, False)]) False