jax.named_call#

jax.named_call(fun, *, name=None)[源代码]#

在分阶段输出 JAX 计算时,为函数添加用户指定的名称。

当为即时编译到 XLA(或其他后端(如 TensorFlow))分阶段输出计算时,JAX 会运行您的 Python 程序,但默认情况下不会保留任何与它关联的函数名称或其他元数据。这会使调试程序的分阶段输出(和/或编译)表示变得复杂,因为每个正在执行的操作的上下文信息有限。

named_call 告诉 JAX 将给定函数作为具有特定名称的子计算分阶段输出。当使用 XLA 编译分阶段输出的程序时,这些命名的子计算会保留下来,并显示在调试实用程序中,例如 TensorBoard 中的 TensorFlow Profiler。当使用 experimental.jax2tf.convert() 将 JAX 程序分阶段输出到 TensorFlow 时,名称也会被保留。

参数:
  • fun (F) – 要包装的函数。这可以是任何可调用对象。

  • name (str | None | None) – 可选。用于命名在名称作用域内创建的所有子计算的前缀。如果未指定,则使用 fun.__name__。

返回:

一个被包装在 name_scope 中的 fun 版本。

返回类型:

F