jax.tree.map#
- jax.tree.map(f, tree, *rest, is_leaf=None)[source]#
将多输入函数映射到 pytree 参数,以生成新的 pytree。
- 参数:
f (Callable[..., Any]) – 该函数接受
1 + len(rest)
个参数,将在 pytrees 的相应叶子处应用。tree (Any) – 要映射的 pytree,每个叶子都为
f
提供第一个位置参数。rest (Any) – pytrees 的元组,每个 pytree 都有与
tree
相同的结构,或者以tree
作为前缀。is_leaf (Callable[[Any], bool] | None | None) – 可选指定的函数,将在每个扁平化步骤处调用。它应该返回一个布尔值,指示是否应该遍历当前对象,或者应该立即停止,并将整个子树视为叶子。
- 返回值:
一个新的 pytree,它与
tree
具有相同的结构,但每个叶子的值由f(x, *xs)
给出,其中x
是tree
中对应叶子的值,而xs
是rest
中对应节点的值元组。- 返回值类型:
Any
示例
>>> import jax >>> jax.tree.map(lambda x: x + 1, {"x": 7, "y": 42}) {'x': 8, 'y': 43}
如果传递多个输入,树的结构将从第一个输入中获取;后续输入只需要在前面加上
tree
>>> jax.tree.map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]]) [[5, 7, 9], [6, 1, 2]]