jax.tree_util.register_static#
- jax.tree_util.register_static(cls)[source]#
将 cls 注册为一个没有叶节点的 pytree。
实例被
jax.jit()
、jax.pmap()
等视为静态。这可以作为使用jit
的static_argnums
和static_argnames
关键字参数、pmap
的static_broadcasted_argnums
等将输入标记为静态的替代方法。- 参数:
cls (type[H]) – 要注册为静态的类型。必须是可散列的,如 https://docs.pythonlang.cn/3/glossary.html#term-hashable 中定义。
- 返回值:
输入类
cls
在添加到 JAX 的 pytree 注册表后,将以不变的形式返回。这使得register_static
可以用作装饰器。- 返回值类型:
type[H]
示例
>>> import jax >>> @jax.tree_util.register_static ... class StaticStr(str): ... pass
现在,此静态字符串可以直接在
jax.jit()
编译的函数中使用,无需使用static_argnums
标记变量为静态。>>> @jax.jit ... def f(x, y, s): ... return x + y if s == 'add' else x - y ... >>> f(1, 2, StaticStr('add')) Array(3, dtype=int32, weak_type=True)