jax.tree_util.register_pytree_with_keys_class

jax.tree_util.register_pytree_with_keys_class#

jax.tree_util.register_pytree_with_keys_class(cls)[source]#

扩展被视为 pytree 中内部节点的类型集。

此函数类似于 register_pytree_node_class,但需要一个类来定义如何使用键对其进行展平。

它是 register_pytree_with_keys 的一个薄包装器,并提供了一个面向类的接口。

参数:

cls (Typ) – 要注册为 pytree 的类型

返回值:

在添加到 JAX 的 pytree 注册表后,输入类 cls 将保持不变。此返回值允许 register_pytree_node_class 用作装饰器。

返回类型:

Typ

另请参阅

示例

>>> from jax.tree_util import register_pytree_with_keys_class, GetAttrKey
>>> @register_pytree_with_keys_class
... class Special:
...   def __init__(self, x, y):
...     self.x = x
...     self.y = y
...   def tree_flatten_with_keys(self):
...     return (((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y)), None)
...   @classmethod
...   def tree_unflatten(cls, aux_data, children):
...     return cls(*children)