jax.tree_util.register_dataclass

jax.tree_util.register_dataclass#

jax.tree_util.register_dataclass(nodetype, data_fields, meta_fields, drop_fields=())[source]#

扩展被视为 pytree 中内部节点的类型集。

这与 register_pytree_with_keys_class 不同,因为 C++ 注册表使用优化的 C++ 数据类内置函数而不是参数函数。

有关注册 pytree 的更多信息,请参阅 扩展 pytree

参数:
  • nodetype (Typ) – 要作为 pytree 内部节点处理的 Python 类型。假设它具有 dataclass 的语义:即,类属性代表对象的全部状态,并且可以作为关键字传递给类构造函数以创建对象的副本。所有定义的属性都应该在 meta_fieldsdata_fields 中列出。

  • meta_fields (Sequence[str]) – 辅助数据字段名。这些字段 *必须* 包含静态、可哈希、不可变的对象,因为这些对象用于生成 JIT 缓存键。特别是,meta_fields 不能包含 jax.Arraynumpy.ndarray 对象。

  • data_fields (Sequence[str]) – 数据字段名。这些字段 *必须* 是 JAX 兼容的对象,例如数组 (jax.Arraynumpy.ndarray)、标量,或叶节点为数组或标量的 pytree。请注意,data_fields 可以是 None,因为 JAX 将其识别为一个空 pytree。

  • drop_fields (Sequence[str])

返回:

在被添加到 JAX 的 pytree 注册表后,输入类 nodetype 会保持不变。此返回值允许 register_dataclass 被部分评估并用作装饰器,如以下示例所示。

返回类型:

Typ

示例

>>> from dataclasses import dataclass
>>> from functools import partial
>>>
>>> @partial(jax.tree_util.register_dataclass,
...          data_fields=['x', 'y'],
...          meta_fields=['op'])
... @dataclass
... class MyStruct:
...   x: jax.Array
...   y: jax.Array
...   op: str
...
>>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
>>> m
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

现在这个类已被注册,它可以与 jax.tree_util 中的函数一起使用。

>>> leaves, treedef = jax.tree.flatten(m)
>>> leaves
[Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)]
>>> treedef
PyTreeDef(CustomNode(MyStruct[('add',)], [*, *]))
>>> jax.tree.unflatten(treedef, leaves)
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

特别是,此注册允许 m 被无缝地传递到用 jax.jit() 和其他 JAX 变换封装的代码中。

>>> @jax.jit
... def compiled_func(m):
...   if m.op == 'add':
...     return m.x + m.y
...   else:
...     raise ValueError(f"{m.op=}")
...
>>> compiled_func(m)
Array([1., 2., 3.], dtype=float32)