jax.export.register_namedtuple_serialization#

jax.export.register_namedtuple_serialization(nodetype, *, serialized_name)[源代码]#

注册一个 namedtuple 以进行序列化和反序列化。

JAX 原生支持 collections.namedtuple 的 PyTree,不需要调用 jax.tree_util.register_pytree_node。但是,如果您想序列化具有 namedtuple 类型输入或输出的函数,则必须注册该类型以进行序列化。

参数:
  • nodetype (type[T]) – 我们要序列化的 PyTree 节点的类型。尝试为 nodetype 注册多个序列化是错误的。在反序列化时,此类型必须具有在序列化期间存在的相同键集。

  • serialized_name (str) – 一个字符串,将出现在序列化中,并将在反序列化期间用于查找注册。尝试为 serialized_name 注册多个序列化是错误的。

返回:

nodetype 传递的类型相同,以便此函数可以用作类装饰器。

返回类型:

type[T]