jax.experimental.host_callback.barrier_wait

jax.experimental.host_callback.barrier_wait#

jax.experimental.host_callback.barrier_wait(logging_name=None)[source]#

阻止调用线程,直到所有当前的 outfeed 处理完成。

等待所有设备上正在运行的计算的所有回调都被 Python 回调接收并处理。如果在处理回调期间出现异常,则引发 CallbackException。

这是通过向我们正在监听 outfeed 的所有设备排队一个特殊的 tap 计算来实现的。一旦所有这些 tap 计算完成,我们就从 barrier_wait 返回。

注意:如果任何设备繁忙且无法接受新的计算,这将导致死锁。

参数:

logging_name (str | None | None) – 一个可选字符串,将在此调用的日志语句中使用。请参阅模块文档中的“调试”。

有关更多详细信息,请参阅 jax.experimental.host_callback 模块文档。