jax.export.register_pytree_node_serialization#
- jax.export.register_pytree_node_serialization(nodetype, *, serialized_name, serialize_auxdata, deserialize_auxdata, from_children=None)[源代码]#
注册自定义 PyTree 节点的序列化和反序列化。
在对非原生支持类型的 PyTree 节点进行序列化和反序列化之前,您必须使用此函数。 我们序列化 Exported 的 in_tree 和 out_tree 字段的 PyTree 节点,它们是导出函数的调用约定的一部分。
此函数必须在调用 jax.tree_util.register_pytree_node 之后调用(除了 collections.namedtuple,它不需要调用 register_pytree_node)。
- 参数:
nodetype (类型[T]) – 我们要序列化的 PyTree 节点的类型。尝试为 nodetype 注册多个序列化是错误的。
serialized_name (字符串) – 一个字符串,它将出现在序列化中,并在反序列化期间用于查找注册。 尝试为 serialized_name 注册多个序列化是错误的。
serialize_auxdata (_SerializeAuxData) – 序列化 PyTree 辅助数据(由 jax.tree_util.register_pytree_node 的 flatten_func 参数返回)。
deserialize_auxdata (_DeserializeAuxData) – 反序列化由 serialize_auxdata 序列化的辅助数据。
from_children (_BuildFromChildren | None | None) – 如果存在,这是一个函数,它接受 deserialize_auxdata 的结果以及一些子节点,并创建一个 nodetype 的实例。 这类似于传递给 jax.tree_util.register_pytree_node 的 unflatten_func。 如果不存在,我们将查找并使用 unflatten_func。 这对于 collections.namedtuple 是必需的,它没有 register_pytree_node,但覆盖该函数可能很有用。 请注意,from_children 的结果仅与 jax.tree_util.tree_structure 一起使用以构造适当的 PyTree 节点,它不用于构造序列化函数的输出。
- 返回:
与 nodetype 传递的类型相同,因此此函数可以用作类装饰器。
- 返回类型:
类型[T]