jax.tree_util.tree_transpose
- jax.tree_util.tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose)
Alias of jax.tree.transpose().
- Parameters:
outer_treedef (PyTreeDef)
inner_treedef (PyTreeDef | None)
pytree_to_transpose (Any)
- Return type:
Any
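A minimal sketch of what the transpose does, assuming a recent JAX release that provides the jax.tree module (jax.tree.structure, jax.tree.transpose). The input is expected to have the structure of outer_treedef composed with inner_treedef. Here the outer structure is a two-element list and the inner structure is a dict with keys "a" and "b", so a list of dicts becomes a dict of lists. The record values are illustrative only.

```python
import jax

# Illustrative data: a batch of two records, each a dict of scalar leaves.
records = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]

# Only the tree shape matters here; the placeholder leaf values (0) are ignored.
outer_treedef = jax.tree.structure([0, 0])           # list of two subtrees
inner_treedef = jax.tree.structure({"a": 0, "b": 0})  # dict with keys "a", "b"

# Swap the nesting: list-of-dicts becomes dict-of-lists.
result = jax.tree.transpose(outer_treedef, inner_treedef, records)
print(result)  # {'a': [1, 3], 'b': [2, 4]}

# Swapping the two treedefs transposes the result back to the original nesting.
roundtrip = jax.tree.transpose(inner_treedef, outer_treedef, result)
print(roundtrip == records)  # True
```

Because jax.tree_util.tree_transpose is an alias, calling it with the same arguments gives the same result.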