并发#

JAX 对 Python 并发的支持有限。

客户端可以从单独的 Python 线程并发调用 JAX API(例如,jit()grad())。

不允许从多个线程并发操作 JAX 跟踪值。换句话说,虽然可以从多个线程调用使用 JAX 跟踪的函数(例如,jit()),但您不能使用多线程来操作传递给 jit() 的函数 f 的实现内部的 JAX 值。 如果这样做,最有可能的结果是 JAX 出现一个神秘的错误。