jax.tree_util.register_dataclass#
- jax.tree_util.register_dataclass(nodetype, data_fields, meta_fields, drop_fields=())[source]#
扩展被视为 pytree 中内部节点的类型集。
这与
register_pytree_with_keys_class
不同,因为 C++ 注册表使用优化的 C++ 数据类内置函数而不是参数函数。有关注册 pytree 的更多信息,请参阅 扩展 pytree。
- 参数:
nodetype (Typ) – 要作为 pytree 内部节点处理的 Python 类型。假设它具有
dataclass
的语义:即,类属性代表对象的全部状态,并且可以作为关键字传递给类构造函数以创建对象的副本。所有定义的属性都应该在meta_fields
或data_fields
中列出。meta_fields (Sequence[str]) – 辅助数据字段名。这些字段 *必须* 包含静态、可哈希、不可变的对象,因为这些对象用于生成 JIT 缓存键。特别是,
meta_fields
不能包含jax.Array
或numpy.ndarray
对象。data_fields (Sequence[str]) – 数据字段名。这些字段 *必须* 是 JAX 兼容的对象,例如数组 (
jax.Array
或numpy.ndarray
)、标量,或叶节点为数组或标量的 pytree。请注意,data_fields
可以是None
,因为 JAX 将其识别为一个空 pytree。drop_fields (Sequence[str])
- 返回:
在被添加到 JAX 的 pytree 注册表后,输入类
nodetype
会保持不变。此返回值允许register_dataclass
被部分评估并用作装饰器,如以下示例所示。- 返回类型:
Typ
示例
>>> from dataclasses import dataclass >>> from functools import partial >>> >>> @partial(jax.tree_util.register_dataclass, ... data_fields=['x', 'y'], ... meta_fields=['op']) ... @dataclass ... class MyStruct: ... x: jax.Array ... y: jax.Array ... op: str ... >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add') >>> m MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
现在这个类已被注册,它可以与
jax.tree_util
中的函数一起使用。>>> leaves, treedef = jax.tree.flatten(m) >>> leaves [Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)] >>> treedef PyTreeDef(CustomNode(MyStruct[('add',)], [*, *])) >>> jax.tree.unflatten(treedef, leaves) MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
特别是,此注册允许
m
被无缝地传递到用jax.jit()
和其他 JAX 变换封装的代码中。>>> @jax.jit ... def compiled_func(m): ... if m.op == 'add': ... return m.x + m.y ... else: ... raise ValueError(f"{m.op=}") ... >>> compiled_func(m) Array([1., 2., 3.], dtype=float32)