jax.tree_util.register_pytree_with_keys_class#

jax.tree_util.register_pytree_with_keys_class(cls)[源代码]#

扩展了在 pytrees 中被视为内部节点的类型集。

此函数类似于 register_pytree_node_class,但需要一个定义如何使用键展平的类。

它是 register_pytree_with_keys 的一个薄包装器,并提供一个面向类的接口

参数:

cls (Typ) – 要注册为 pytree 的类型

返回:

输入类 cls 在添加到 JAX 的 pytree 注册表后会保持不变地返回。此返回值允许将 register_pytree_node_class 用作装饰器。

返回类型:

Typ

另请参阅

示例

>>> from jax.tree_util import register_pytree_with_keys_class, GetAttrKey
>>> @register_pytree_with_keys_class
... class Special:
...   def __init__(self, x, y):
...     self.x = x
...     self.y = y
...   def tree_flatten_with_keys(self):
...     return (((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y)), None)
...   @classmethod
...   def tree_unflatten(cls, aux_data, children):
...     return cls(*children)