jax.named_call#
- jax.named_call(fun, *, name=None)[source]#
在分阶段 JAX 计算时,为函数添加用户指定的名称。
在为即时编译到 XLA(或其他后端,如 TensorFlow)分阶段计算时,JAX 会运行你的 Python 程序,但默认情况下不会保留与之关联的任何函数名称或其他元数据。这会导致调试分阶段(或编译)的程序表示变得复杂,因为正在执行的每个操作的上下文信息有限。
named_call 告诉 JAX 将给定函数分阶段作为一个具有特定名称的子计算。当分阶段程序使用 XLA 编译时,这些命名子计算将被保留并在 TensorBoard 中的 TensorFlow Profiler 等调试实用程序中显示。在使用
experimental.jax2tf.convert()
将 JAX 程序分阶段到 TensorFlow 时,名称也会被保留。- 参数:
fun (F) – 要包装的函数。这可以是任何可调用对象。
name (str | None | None) – 可选。用于命名名称作用域内创建的所有子计算的名称前缀。如果未指定,则使用 fun.__name__。
- 返回值:
一个包装在 name_scope 中的 fun 版本。
- 返回类型:
F