jax.profiler.annotate_function#

jax.profiler.annotate_function(func, name=None, **decorator_kwargs)[源代码]#

为函数执行生成跟踪事件的装饰器。

例如

>>> @jax.profiler.annotate_function
... def f(x):
...   return jnp.dot(x, x.T).block_until_ready()
>>>
>>> result = f(jnp.ones((1000, 1000)))

如果在 TensorBoard 跟踪进程时发生函数执行,这将导致 “f” 事件显示在跟踪时间线上。

可以通过 functools.partial() 将参数传递给装饰器。

>>> from functools import partial
>>> @partial(jax.profiler.annotate_function, name="event_name")
... def f(x):
...   return jnp.dot(x, x.T).block_until_ready()
>>> result = f(jnp.ones((1000, 1000)))
参数:
  • func (Callable)

  • name (str | None | None)