基于 JAX 构建#

学习高级 JAX 用法的一个好方法是了解其他库如何使用 JAX,包括它们如何将该库集成到其 API 中,它在数学上添加了哪些功能,以及如何在其他库中将其用于计算加速。

以下是如何使用 JAX 的功能来定义跨多个领域和软件包的加速计算的示例。

梯度计算#

简易梯度计算是 JAX 的一个关键特性。在 JaxOpt 库中,值和梯度直接用于用户,在 其源代码中的多种优化算法中。

同样,上面提到的 Dynamax Optax 配对是一个例子,说明梯度如何启用在历史上具有挑战性的估计方法 使用 Optax 的最大似然期望

在多个设备上实现单核计算加速#

在 JAX 中定义的模型可以通过 JIT 编译实现单次计算加速。相同的编译代码可以发送到 CPU 设备、GPU 或 TPU 设备以获得额外的加速,通常无需进行额外的更改。这使得从开发到生产的流程非常顺畅。在 Dynamax 中,线性状态空间模型求解器中计算量大的部分已经进行了 JIT 编译。一个更复杂的例子来自 PyTensor,它会动态编译一个 JAX 函数,然后 JIT 编译构造的函数

使用并行化的单机和多机加速#

JAX 的另一个好处是使用 pmapvmap 函数调用或装饰器来简化计算的并行化。在 Dynamax 中,状态空间模型使用 VMAP 装饰器 进行并行化,多目标跟踪是这种用例的一个实际例子。

将 JAX 代码集成到您或用户的 workflow 中#

JAX 具有很强的可组合性,可以以多种方式使用。JAX 可以以独立模式使用,用户可以在其中定义所有的计算。然而,其他模式,例如使用基于 JAX 构建的库来提供特定功能。这些可以是定义特定类型模型的库,例如神经网络或状态空间模型,或者提供特定功能的库,例如优化。以下是每种模式的更具体示例。

直接使用#

可以直接导入并使用 JAX 从头开始构建模型,如本网站所示,例如在 JAX 教程使用 JAX 的神经网络 中。如果您无法找到适合您特定挑战的预构建代码,或者如果您想减少代码库中的依赖项数量,这可能是最佳选择。

具有暴露的 JAX 的可组合领域特定库#

另一种常见的方法是提供预构建功能的包,无论是模型定义还是某种类型的计算。这些包的组合可以混合和匹配,以实现完整的端到端工作流程,在该流程中定义模型并估计其参数。

一个例子是 Flax,它简化了神经网络的构建。Flax 通常与 Optax 配对使用,其中 Flax 定义神经网络架构,而 Optax 提供优化和模型拟合功能。

另一个是 Dynamax,它可以轻松定义状态空间模型。使用 Dynamax,可以使用 使用 Optax 的最大似然法 来估计参数,或者使用 Blackjax 的 MCMC 来估计完整的贝叶斯后验。

JAX 对用户完全隐藏#

其他库选择将 JAX 完全包裹在其特定于模型的 API 中。一个例子是 PyMC 和 Pytensor,用户可能永远不会直接“看到”JAX,而是用 PyMC 特定的 API 包裹 JAX 函数