基于 JAX#

学习高级 JAX 用法的一个好方法是查看其他库如何使用 JAX,包括它们如何将库集成到其 API 中,JAX 添加了哪些数学功能,以及如何在其他库中使用 JAX 来进行计算加速。

以下是关于如何在众多领域和软件包中使用 JAX 的功能来定义加速计算的示例。

梯度计算#

轻松的梯度计算是 JAX 的关键功能。在 JaxOpt 库 中,value 和 grad 直接用于用户在 其源代码 中的多个优化算法中。

同样,上面提到的 Dynamax Optax 配对是梯度使估计方法成为可能的示例,这些方法在历史上是具有挑战性的 使用 Optax 进行最大似然期望

在单个核心上跨多个设备进行计算加速#

在 JAX 中定义的模型随后可以被编译,以通过 JIT 编译实现单次计算加速。然后,相同的编译代码可以发送到 CPU 设备、GPU 设备或 TPU 设备以进一步加速,通常不需要进行任何其他更改。这使得从开发到生产的流程变得非常顺畅。在 Dynamax 中,线性状态空间模型求解器中计算量最大的部分已经被 jitted。一个更复杂的示例来自 PyTensor,它动态编译 JAX 函数,然后 jits 构造的函数

使用并行化进行单机和多机加速#

JAX 的另一个优点是使用 pmapvmap 函数调用或装饰器并行化计算的简单性。在 Dynamax 中,状态空间模型使用 VMAP 装饰器 并行化,多目标跟踪是这种用例的实际示例。

将 JAX 代码整合到您或您的用户的工作流程中#

JAX 具有高度可组合性,可以以多种方式使用。JAX 可以与独立模式一起使用,用户在其中自行定义所有计算。但是,还有其他模式,例如使用基于 JAX 的库来提供特定功能。这些库可以定义特定类型的模型,例如神经网络或状态空间模型,或者提供特定功能,例如优化。以下是每种模式的更具体示例。

直接使用#

如本网站所示,Jax 可以直接导入和利用来“从头开始”构建模型,例如在 JAX 教程使用 JAX 的神经网络 中。如果您无法找到针对特定挑战的预构建代码,或者您希望减少代码库中的依赖项,这可能是最佳选择。

具有 JAX 暴露的可组合域特定库#

另一种常见的方法是提供预构建功能的包,无论它是模型定义还是某种类型的计算。这些包的组合可以混合和匹配以实现完整的端到端工作流,其中定义模型并估计其参数。

一个例子是 Flax,它简化了神经网络的构建。Flax 通常与 Optax 配对,其中 Flax 定义了神经网络架构,而 Optax 提供了优化和模型拟合功能。

另一个例子是 Dynamax,它允许轻松定义状态空间模型。使用 Dynamax,可以使用 Optax 的最大似然 估计参数,或者可以使用 Blackjax 的 MCMC 估计完整的贝叶斯后验。

对用户完全隐藏的 JAX#

其他库选择在其特定模型 API 中完全包装 JAX。一个例子是 PyMC 和 Pytensor,在其中用户可能永远不会“看到”JAX,而是使用 PyMC 特定的 API 包装 JAX 函数