jax.named_scope#
- jax.named_scope(name)[source]#
一个上下文管理器,它向 JAX 名称栈添加用户指定的名称。
在将计算分阶段进行即时编译到 XLA(或其他后端,如 TensorFlow)时,JAX 默认情况下不会保留它遇到的 Python 函数的名称(或其他源元数据)。这可能会使调试你程序的分阶段(和/或编译)表示变得复杂,因为正在执行的每个操作都只有有限的上下文信息。
named_scope
告诉 JAX 将给定函数分阶段进行,并对底层操作添加额外的注释。JAX 在内部使用名称栈跟踪这些注释。当分阶段的程序使用 XLA 编译时,这些注释将被保留,并出现在 TensorBoard 中的 TensorFlow Profiler 等调试工具中。在使用experimental.jax2tf.convert()
将 JAX 程序分阶段进行到 TensorFlow 时,名称也会被保留。- 参数:
name (str) – 用于命名名称范围内的所有创建操作的前缀。
- 收益率:
返回
None
,但会进入一个上下文,其中 name 将被追加到活动名称栈。- 返回类型:
Generator[None, None, None]
示例
named_scope
可用作编译函数内的上下文管理器>>> import jax >>> >>> @jax.jit ... def layer(w, x): ... with jax.named_scope("dot_product"): ... logits = w.dot(x) ... with jax.named_scope("activation"): ... return jax.nn.relu(logits)
它也可以用作装饰器
>>> @jax.jit ... @jax.named_scope("layer") ... def layer(w, x): ... logits = w.dot(x) ... return jax.nn.relu(logits)