jax.tree
模块#
用于处理类树容器数据结构的实用工具。
jax.tree
命名空间包含来自 jax.tree_util
的实用工具的别名。
函数列表#
|
在树的叶子上调用 all()。 |
|
扁平化 pytree。 |
|
像 |
|
获取 pytree 的叶子。 |
|
像 |
|
将一个多输入函数映射到 pytree 参数上,以生成一个新的 pytree。 |
|
将一个多输入函数映射到 pytree 的键路径和参数上,以生成一个新的 pytree。 |
|
在树的叶子上调用 reduce()。 |
|
获取 pytree 的 treedef。 |
|
将具有树结构(外部,内部)的树转换为具有结构(内部,外部)的树。 |
|
从 treedef 和叶子重建 pytree。 |