jax.tree_util.register_static

jax.tree_util.register_static#

jax.tree_util.register_static(cls)[source]#

cls 注册为一个没有叶节点的 pytree。

实例被 jax.jit()jax.pmap() 等视为静态。这可以作为使用 jitstatic_argnumsstatic_argnames 关键字参数、pmapstatic_broadcasted_argnums 等将输入标记为静态的替代方法。

参数:

cls (type[H]) – 要注册为静态的类型。必须是可散列的,如 https://docs.pythonlang.cn/3/glossary.html#term-hashable 中定义。

返回值:

输入类 cls 在添加到 JAX 的 pytree 注册表后,将以不变的形式返回。这使得 register_static 可以用作装饰器。

返回值类型:

type[H]

示例

>>> import jax
>>> @jax.tree_util.register_static
... class StaticStr(str):
...   pass

现在,此静态字符串可以直接在 jax.jit() 编译的函数中使用,无需使用 static_argnums 标记变量为静态。

>>> @jax.jit
... def f(x, y, s):
...   return x + y if s == 'add' else x - y
...
>>> f(1, 2, StaticStr('add'))
Array(3, dtype=int32, weak_type=True)