jax.experimental.host_callback.barrier_wait#
- jax.experimental.host_callback.barrier_wait(logging_name=None)[source]#
阻止调用线程,直到所有当前的 outfeed 处理完成。
等待所有设备上正在运行的计算的所有回调都被 Python 回调接收并处理。如果在处理回调期间出现异常,则引发 CallbackException。
这是通过向我们正在监听 outfeed 的所有设备排队一个特殊的 tap 计算来实现的。一旦所有这些 tap 计算完成,我们就从 barrier_wait 返回。
注意:如果任何设备繁忙且无法接受新的计算,这将导致死锁。
- 参数:
logging_name (str | None | None) – 一个可选字符串,将在此调用的日志语句中使用。请参阅模块文档中的“调试”。
有关更多详细信息,请参阅
jax.experimental.host_callback
模块文档。