jax.named_scope

目录

jax.named_scope#

jax.named_scope(name)[source]#

一个上下文管理器,它向 JAX 名称栈添加用户指定的名称。

在将计算分阶段进行即时编译到 XLA(或其他后端,如 TensorFlow)时,JAX 默认情况下不会保留它遇到的 Python 函数的名称(或其他源元数据)。这可能会使调试你程序的分阶段(和/或编译)表示变得复杂,因为正在执行的每个操作都只有有限的上下文信息。

named_scope 告诉 JAX 将给定函数分阶段进行,并对底层操作添加额外的注释。JAX 在内部使用名称栈跟踪这些注释。当分阶段的程序使用 XLA 编译时,这些注释将被保留,并出现在 TensorBoard 中的 TensorFlow Profiler 等调试工具中。在使用 experimental.jax2tf.convert() 将 JAX 程序分阶段进行到 TensorFlow 时,名称也会被保留。

参数:

name (str) – 用于命名名称范围内的所有创建操作的前缀。

收益率:

返回 None,但会进入一个上下文,其中 name 将被追加到活动名称栈。

返回类型:

Generator[None, None, None]

示例

named_scope 可用作编译函数内的上下文管理器

>>> import jax
>>>
>>> @jax.jit
... def layer(w, x):
...   with jax.named_scope("dot_product"):
...     logits = w.dot(x)
...   with jax.named_scope("activation"):
...     return jax.nn.relu(logits)

它也可以用作装饰器

>>> @jax.jit
... @jax.named_scope("layer")
... def layer(w, x):
...   logits = w.dot(x)
...   return jax.nn.relu(logits)