jax.tree.map#

jax.tree.map(f, tree, *rest, is_leaf=None)[源代码]#

将多输入函数映射到 pytree 参数上,以生成新的 pytree。

参数
  • f (Callable[..., Any]) – 函数,接受 1 + len(rest) 个参数,在 pytree 的相应叶子上应用。

  • tree (Any) – 要映射的 pytree,其中每个叶子提供 f 的第一个位置参数。

  • rest (任意类型) – 一个 pytree 元组,其中每个 pytree 的结构都与 tree 相同,或者以 tree 作为前缀。

  • is_leaf (Callable[[任意类型], bool] | None | None) – 一个可选的指定函数,将在每个扁平化步骤中被调用。它应该返回一个布尔值,指示是否应该遍历当前对象进行扁平化,或者是否应该立即停止,将整个子树视为一个叶子节点。

返回:

一个新的 pytree,其结构与 tree 相同,但每个叶子节点的值由 f(x, *xs) 给出,其中 xtree 中相应叶子节点的值,xsrest 中对应节点的元组值。

返回类型:

任意类型

示例

>>> import jax
>>> jax.tree.map(lambda x: x + 1, {"x": 7, "y": 42})
{'x': 8, 'y': 43}

如果传递多个输入,则树的结构取自第一个输入;后续输入只需要以 tree 作为前缀

>>> jax.tree.map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
[[5, 7, 9], [6, 1, 2]]