jax.export.register_pytree_node_serialization#

jax.export.register_pytree_node_serialization(nodetype, *, serialized_name, serialize_auxdata, deserialize_auxdata, from_children=None)[源代码]#

注册自定义 PyTree 节点的序列化和反序列化。

在对非原生支持类型的 PyTree 节点进行序列化和反序列化之前,您必须使用此函数。 我们序列化 Exportedin_treeout_tree 字段的 PyTree 节点,它们是导出函数的调用约定的一部分。

此函数必须在调用 jax.tree_util.register_pytree_node 之后调用(除了 collections.namedtuple,它不需要调用 register_pytree_node)。

参数:
  • nodetype (类型[T]) – 我们要序列化的 PyTree 节点的类型。尝试为 nodetype 注册多个序列化是错误的。

  • serialized_name (字符串) – 一个字符串,它将出现在序列化中,并在反序列化期间用于查找注册。 尝试为 serialized_name 注册多个序列化是错误的。

  • serialize_auxdata (_SerializeAuxData) – 序列化 PyTree 辅助数据(由 jax.tree_util.register_pytree_nodeflatten_func 参数返回)。

  • deserialize_auxdata (_DeserializeAuxData) – 反序列化由 serialize_auxdata 序列化的辅助数据。

  • from_children (_BuildFromChildren | None | None) – 如果存在,这是一个函数,它接受 deserialize_auxdata 的结果以及一些子节点,并创建一个 nodetype 的实例。 这类似于传递给 jax.tree_util.register_pytree_nodeunflatten_func。 如果不存在,我们将查找并使用 unflatten_func。 这对于 collections.namedtuple 是必需的,它没有 register_pytree_node,但覆盖该函数可能很有用。 请注意,from_children 的结果仅与 jax.tree_util.tree_structure 一起使用以构造适当的 PyTree 节点,它不用于构造序列化函数的输出。

返回:

nodetype 传递的类型相同,因此此函数可以用作类装饰器。

返回类型:

类型[T]