jax.tree.reduce

内容

jax.tree.reduce#

jax.tree.reduce(function: Callable[[T, Any], T], tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) T[source]#
jax.tree.reduce(function: Callable[[T, Any], T], tree: Any, initializer: T, is_leaf: Callable[[Any], bool] | None = None) T

对树的叶子节点调用 reduce() 函数。

参数:
  • function – 归约函数

  • tree – 要归约的 PyTree

  • initializer – 可选的初始值

  • is_leaf – 可选指定的函数,将在每个扁平化步骤中被调用。它应该返回一个布尔值,指示是否应该遍历当前对象进行扁平化,或者是否应该立即停止,将整个子树视为一个叶子节点。

返回值:

归约后的值。

返回类型:

结果

示例

>>> import jax
>>> import operator
>>> jax.tree.reduce(operator.add, [1, (2, 3), [4, 5, 6]])
21