jax.tree_util.treedef_tuple#
- jax.tree_util.treedef_tuple(treedefs)[source]#
从子 treedef 的可迭代对象中创建一个元组 treedef。
- 参数:
treedefs (Iterable[PyTreeDef]) – PyTree 结构的可迭代对象
- 返回值:
表示结构元组的单个 treedef
- 返回类型:
PyTreeDef
示例
>>> import jax >>> x = [1, 2, 3] >>> y = {'a': 4, 'b': 5} >>> x_tree = jax.tree.structure(x) >>> y_tree = jax.tree.structure(y) >>> xy_tree = jax.tree_util.treedef_tuple([x_tree, y_tree]) >>> xy_tree == jax.tree.structure((x, y)) True