jax.tree_util.register_pytree_with_keys_class#
- jax.tree_util.register_pytree_with_keys_class(cls)[源代码]#
扩展了在 pytrees 中被视为内部节点的类型集。
此函数类似于
register_pytree_node_class
,但需要一个定义如何使用键展平的类。它是
register_pytree_with_keys
的一个薄包装器,并提供一个面向类的接口- 参数:
cls (Typ) – 要注册为 pytree 的类型
- 返回:
输入类
cls
在添加到 JAX 的 pytree 注册表后会保持不变地返回。此返回值允许将register_pytree_node_class
用作装饰器。- 返回类型:
Typ
另请参阅
register_static()
:用于注册静态 pytree 的更简单的 API。register_dataclass()
:用于注册数据类的更简单 API。
示例
>>> from jax.tree_util import register_pytree_with_keys_class, GetAttrKey >>> @register_pytree_with_keys_class ... class Special: ... def __init__(self, x, y): ... self.x = x ... self.y = y ... def tree_flatten_with_keys(self): ... return (((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y)), None) ... @classmethod ... def tree_unflatten(cls, aux_data, children): ... return cls(*children)