jax.tree.transpose#
- jax.tree.transpose(outer_treedef, inner_treedef, pytree_to_transpose)[source]#
将具有树结构 (outer, inner) 的树转换为具有结构 (inner, outer) 的树。
- 参数::
outer_treedef (tree_util.PyTreeDef) – 表示外层树的 PyTreeDef。
inner_treedef (tree_util.PyTreeDef | None) – 表示内层树的 PyTreeDef。如果为 None,则将从 outer_treedef 和 pytree_to_transpose 的结构中推断出来。
pytree_to_transpose (Any) – 要转置的 pytree。
- 返回值::
转置后的 pytree。
- 返回类型::
transposed_pytree
示例
>>> import jax >>> tree = [(1, 2, 3), (4, 5, 6)] >>> inner_structure = jax.tree.structure(('*', '*', '*')) >>> outer_structure = jax.tree.structure(['*', '*']) >>> jax.tree.transpose(outer_structure, inner_structure, tree) ([1, 4], [2, 5], [3, 6])
推断内部结构
>>> jax.tree.transpose(outer_structure, None, tree) ([1, 4], [2, 5], [3, 6])