jax.tree_util.treedef_tuple

jax.tree_util.treedef_tuple#

jax.tree_util.treedef_tuple(treedefs)[source]#

从子 treedef 的可迭代对象中创建一个元组 treedef。

参数:

treedefs (Iterable[PyTreeDef]) – PyTree 结构的可迭代对象

返回值:

表示结构元组的单个 treedef

返回类型:

PyTreeDef

示例

>>> import jax
>>> x = [1, 2, 3]
>>> y = {'a': 4, 'b': 5}
>>> x_tree = jax.tree.structure(x)
>>> y_tree = jax.tree.structure(y)
>>> xy_tree = jax.tree_util.treedef_tuple([x_tree, y_tree])
>>> xy_tree == jax.tree.structure((x, y))
True