jax.profiler.StepTraceAnnotation

jax.profiler.StepTraceAnnotation#

class jax.profiler.StepTraceAnnotation(name, **kwargs)[source]#

生成性能分析器中步骤跟踪事件的上下文管理器。

步骤跟踪事件跨越上下文包含的代码的持续时间。性能分析器将提供每个步骤跟踪事件的性能分析。

例如,它可用于标记训练步骤并使性能分析器能够提供每个步骤的性能分析。

>>> while global_step < NUM_STEPS:                                           
...   with jax.profiler.StepTraceAnnotation("train", step_num=global_step):  
...     train_step()                                                         
...     global_step += 1                                                     

如果事件发生在 TensorBoard 跟踪进程时,这将导致“train xx”事件显示在跟踪时间线上。此外,如果使用加速器,设备跟踪时间线也将显示“train xx”事件。请注意,“step_num”可以设置为关键字参数,以将全局步骤号传递给性能分析器。

参数:

name (str)

__init__(self, arg0: str, /, **kwargs) None[source]#
参数:

name (str)

方法

__init__(self, arg0, /, **kwargs)

属性

is_enabled

set_metadata