jax.tree.reduce#
- jax.tree.reduce(function: Callable[[T, Any], T], tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) T [源代码]#
- jax.tree.reduce(function: Callable[[T, Any], T], tree: Any, initializer: T, is_leaf: Callable[[Any], bool] | None = None) T
在树的叶子上调用 reduce() 函数。
- 参数:
function – 归约函数
tree – 要进行归约的 pytree
initializer – 可选的初始值
is_leaf – 一个可选的指定函数,它将在每个展平步骤中被调用。它应该返回一个布尔值,指示展平是否应该遍历当前对象,或者是否应该立即停止,并将整个子树视为叶子。
- 返回:
归约后的值。
- 返回类型:
结果
示例
>>> import jax >>> import operator >>> jax.tree.reduce(operator.add, [1, (2, 3), [4, 5, 6]]) 21