JEP 18137: JAX NumPy & SciPy wrappers 的范围#

Jake VanderPlas

2023 年 10 月

到目前为止,jax.numpyjax.scipy 的预期范围一直相对模糊。本文档提出这两个包的明确范围,以更好地指导和评估未来的贡献,并促进删除一些超出范围的代码。

背景#

从一开始,JAX 就旨在为在 XLA 中执行代码提供一个类似 NumPy 的 API,并且该项目开发的很大一部分是构建 jax.numpyjax.scipy 命名空间作为基于 JAX 的 NumPy 和 SciPy API 实现。一直以来,人们都隐含地理解,numpyscipy 的某些部分超出了 JAX 的范围,但是这个范围并没有得到很好的定义。这会导致贡献者感到困惑和沮丧,因为对于潜在的 jax.numpyjax.scipy 贡献是否会被 JAX 接受,没有明确的答案。

为什么要限制范围?#

为了避免这种没有说出口的事情,我们应该明确说明:事实是,包含在 JAX 这样的项目中的任何代码都会给开发人员带来微乎其微但并非零的持续维护负担。一个项目随时间的成功直接关系到维护人员是否能够继续维护项目的所有部分:记录功能、回答问题、修复错误等等。为了任何软件工具的长期成功和可持续性,维护人员必须仔细权衡任何特定贡献是否会在项目目标和资源方面对项目产生积极影响。

评估标准#

本文档提出六个轴的标准,可以根据这些标准判断任何特定 numpyscipy API 是否可以包含在 JAX 中。在所有轴上表现强大的 API 是包含在 JAX 包中的绝佳候选者;在六个轴中的任何一个轴上表现出明显的弱点,都是反对将该 API 包含在 JAX 中的有力论据。

轴 1:XLA 对齐#

我们考虑的第一个轴是所提议 API 与原生 XLA 操作对齐的程度。例如,jax.numpy.exp() 是一个函数,它或多或少地直接映射到 jax.lax.expnumpyscipy.specialnumpy.linalgscipy.linalg 等中的许多函数都满足此标准:在考虑将这些函数包含在 JAX 中时,这些函数通过了 XLA 对齐检查。

另一方面,有一些函数,比如 numpy.unique(),它们没有直接对应于任何 XLA 操作,并且在某些情况下,它们与 JAX 当前的计算模型(该模型要求使用静态形状的数组)从根本上不兼容(例如,unique 返回一个值相关的动态数组形状)。在考虑将这些函数包含在 JAX 中时,它们没有通过 XLA 对齐检查。

我们还将纯函数语义的需要视为该轴的一部分。例如,numpy.random 是基于一个隐式更新的基于状态的 RNG 构建的,这与 JAX 基于 XLA 构建的计算模型从根本上不兼容。

轴 2:数组 API 对齐#

我们考虑的第二个轴侧重于 Python 数组 API 标准:从某种意义上说,这是社区驱动的概述,说明了哪些数组操作对于广泛的用户社区的数组面向编程至关重要。如果 numpyscipy 中的 API 列在数组 API 标准中,则这是一个强烈的信号,表明 JAX 应该包含该 API。使用上面的例子,数组 API 标准包括 numpy.unique() 的几种变体(unique_allunique_countsunique_inverseunique_values),这表明,尽管该函数没有与 XLA 精确对齐,但它对 Python 用户社区足够重要,因此 JAX 可能应该实现它。

轴 3:下游实现的存在#

对于与轴 1 或 2 不对齐的功能,将其包含在 JAX 中的一个重要考虑因素是是否存在提供所讨论功能的良好支持的下游包。一个很好的例子是 scipy.optimize:虽然 JAX 确实包含一个最小的 scipy.optimize 功能包装器集,但 JAXopt 包提供了更加完整的处理方式,该包由 JAX 协作者积极维护。在这些情况下,我们应该倾向于指导用户和贡献者使用这些专门的包,而不是在 JAX 本身中重新实现这些 API。

轴 4:实现的复杂性和鲁棒性#

对于与 XLA 不对齐的功能,一个考虑因素是所提议实现的复杂程度。这在一定程度上与轴 1 相关,但无论如何都需要提出来。许多函数已贡献给 JAX,这些函数具有相对复杂的实现,这些实现难以验证并带来过大的维护负担;一个例子是 jax.scipy.special.bessel_jn():在撰写本文档时,其当前实现是一个非直接的迭代近似,在某些域中存在 收敛问题,并且 提出的修复方法 增加了进一步的复杂性。如果我们在接受贡献时更仔细地权衡实现的复杂性和鲁棒性,我们可能会选择不接受这个贡献到包中。

轴 5:函数式 API 与面向对象 API#

JAX 最适合使用函数式 API 而不是面向对象 API。面向对象 API 通常会隐藏不纯语义,这使得它们难以良好地实现。NumPy 和 SciPy 通常坚持使用函数式 API,但有时会提供面向对象的便利包装器。

这方面的例子是 numpy.polynomial.Polynomial,它包装了低级操作,例如 numpy.polyadd()numpy.polydiv() 等等。一般来说,当同时存在函数式 API 和面向对象 API 时,JAX 应该避免提供面向对象 API 的包装器,而是提供函数式 API 的包装器。

在只存在面向对象 API 的情况下,JAX 应该避免提供包装器,除非在其他轴上的情况很强。

轴 6:对 JAX 用户和利益相关者的总体“重要性”#

在 JAX 中包含 NumPy/SciPy API 的决定还应考虑该算法对一般用户社区的重要性。承认难以量化谁是“利益相关者”以及如何衡量这种重要性;但是我们包含这一点是为了明确说明,关于在 JAX 的 NumPy 和 SciPy 包装器中包含什么内容的任何决定都将涉及一定程度的无法轻易量化的自由裁量权。

对于现有的 API,在 github 中搜索使用情况可能有助于确定重要性或缺乏重要性;例如,我们可能会回到上面讨论的 jax.scipy.special.bessel_jn():搜索显示该函数在 github 上只有 少量使用,这可能是由于前面提到的精度问题造成的。

评估:范围内的内容?#

在本节中,我们将尝试根据上述标准评估 NumPy 和 SciPy API,包括来自当前 JAX API 的一些示例。这不会是对所有现有函数和类的全面列举,而是一个更一般的按子模块和主题进行的讨论,并给出相关的示例。

NumPy API#

numpy 命名空间#

我们认为 numpy 命名空间中的主要函数本质上都在 JAX 的作用域内,这是因为它与 XLA(轴 1)和 Python 数组 API(轴 2)普遍一致,并且对 JAX 用户社区至关重要(轴 6)。一些函数可能处于边界线(例如 numpy.intersect1d()np.setdiff1d()np.union1d() 可能无法满足部分要求),但为了简单起见,我们声明主 numpy 命名空间中的所有数组函数都在 JAX 的作用域内。

numpy.linalg & numpy.fft#

numpy.linalgnumpy.fft 子模块包含许多与 XLA 提供的功能广泛一致的函数。其他函数具有复杂的设备特定降低,但代表了利益相关者重要性(轴 6)超过复杂性的案例。出于这个原因,我们认为这两个子模块都在 JAX 的作用域内。

numpy.random#

numpy.random 不在 JAX 的作用域内,因为基于状态的 RNG 本质上与 JAX 的计算模型不兼容。我们专注于 jax.random,它使用基于计数器的 PRNG 提供类似的功能。

numpy.ma & numpy.polynomial#

numpy.manumpy.polynomial 子模块主要关注为可以使用其他函数方式表达的计算提供面向对象的接口(轴 5);出于这个原因,我们认为它们不在 JAX 的作用域内。

numpy.testing#

NumPy 的测试功能只对主机端计算有意义,因此我们在 JAX 中不包含任何包装器。也就是说,JAX 数组与 numpy.testing 兼容,并且 JAX 在整个 JAX 测试套件中经常使用它。

SciPy API#

SciPy 在顶层命名空间中没有函数,但包含许多子模块。我们考虑以下每个子模块,省略了已弃用的模块。

scipy.cluster#

scipy.cluster 模块包含用于层次聚类、k 均值和相关算法的工具。这些工具在几个方面比较弱,最好由下游包来提供。JAX 中已经存在一个函数(jax.scipy.cluster.vq.vq()),但在 github 上 没有明显的用法:这表明聚类对 JAX 用户并不重要。

建议:弃用并删除 jax.scipy.cluster.vq()

scipy.constants#

scipy.constants 模块包含数学和物理常数。这些常数可以直接与 JAX 一起使用,因此没有理由在 JAX 中重新实现它们。

scipy.datasets#

scipy.datasets 模块包含用于获取和加载数据集的工具。这些获取的数据集可以直接与 JAX 一起使用,因此没有理由在 JAX 中重新实现它们。

scipy.fft#

scipy.fft 模块包含与 XLA 提供的功能广泛一致的函数,并且在其他方面也表现良好。出于这个原因,我们认为它们在 JAX 的作用域内。

scipy.integrate#

scipy.integrate 模块包含用于数值积分的函数。这些函数中比较复杂的函数(quaddblquadode)不在 JAX 的作用域内(轴 1 和 4),因为它们往往是基于动态评估次数的循环算法。 jax.experimental.ode.odeint() 是相关的,但相当有限,并且没有积极开发。

JAX 目前确实包含 jax.scipy.integrate.trapezoid(),但这只是因为 numpy.trapz() 最近被弃用,取而代之的是它。对于任何特定输入,它的实现可以用一行 jax.numpy 表达式替换,因此它不是一个特别有用的 API 来提供。

基于轴 1、2、4 和 6,scipy.integrate 应该被认为不在 JAX 的作用域内。

建议:删除 jax.scipy.integrate.trapezoid(),它是在 JAX 0.4.14 中添加的。

scipy.interpolate#

scipy.interpolate 模块提供一维或多维插值的低级和面向对象的例程。这些 API 在上面的一些轴上的评分很低:它们是基于类的而不是低级的,除了最简单的方法之外,没有一种方法可以高效地用 XLA 操作表达。

JAX 目前确实包含 scipy.interpolate.RegularGridInterpolator 的包装器。如果我们今天考虑这个贡献,我们可能会根据上述标准拒绝它。但这段代码已经相当稳定,因此继续维护它没有什么坏处。

展望未来,我们应该考虑 scipy.interpolate 中的其他成员不在 JAX 的作用域内。

scipy.io#

scipy.io 子模块与文件输入/输出有关。没有理由在 JAX 中重新实现它。

scipy.linalg#

scipy.linalg 子模块包含与 XLA 提供的功能广泛一致的函数,快速线性代数对 JAX 用户社区至关重要。出于这个原因,我们认为它在 JAX 的作用域内。

scipy.ndimage#

scipy.ndimage 子模块包含一组用于处理图像数据的工具。其中许多与 scipy.signal 中的工具重叠(例如卷积和滤波)。JAX 目前提供了一个 scipy.ndimage API,位于 jax.scipy.ndimage.map_coordinates()。此外,JAX 在 jax.image 模块中提供了一些与图像相关的工具。DeepMind 生态系统包括 dm-pix,这是一个更全面的工具集,用于在 JAX 中进行图像处理。鉴于所有这些因素,我认为 scipy.ndimage 应该被认为是 JAX 核心范围之外的;我们可以将感兴趣的用户和贡献者引导到 dm-pix。我们可以考虑将 map_coordinates 移动到 dm-pix 或其他合适的包中。

scipy.odr#

scipy.odr 模块提供了围绕 ODRPACK 的面向对象包装器,用于执行正交距离回归。目前尚不清楚这是否可以用现有的 JAX 原语清楚地表达,因此我们认为它超出了 JAX 本身的作用域。

scipy.optimize#

scipy.optimize 模块提供了用于优化的高级和低级接口。这些功能对许多 JAX 用户来说非常重要,并且在 JAX 的早期就创建了 jax.scipy.optimize 包装器。然而,这些例程的开发者很快意识到 scipy.optimize API 过于限制,不同的团队开始开发 JAXopt 包和 Optimistix 包,每个包都包含在 JAX 中更全面、测试更好的优化例程集。

由于这些支持良好的外部包,我们现在认为 scipy.optimize 超出了 JAX 的范围。

建议:弃用 jax.scipy.optimize 或将其转换为围绕可选 JAXopt 或 Optimistix 依赖项的轻量级包装器。

🟡 scipy.signal#

scipy.signal 模块是混合的:一些函数完全属于 JAX 范围(例如 correlateconvolve,它们是 lax.conv_general_dilated 的更易于使用的包装器),而许多其他函数则完全超出了范围(特定领域的工具,没有可行的 XLA 降级路径)。对 jax.scipy.signal 的潜在贡献将不得不根据具体情况进行权衡。

🟡 scipy.sparse#

scipy.sparse 子模块主要包含用于以各种格式存储和操作稀疏矩阵和数组的数据结构。此外,scipy.sparse.linalg 包含许多无矩阵求解器,适用于稀疏矩阵、稠密矩阵和线性算子。

scipy.sparse 数组和矩阵数据结构超出了 JAX 的范围,因为它们与 JAX 的计算模型不一致(例如,许多操作依赖于动态大小的缓冲区)。JAX 开发了 jax.experimental.sparse 模块作为一组替代数据结构,它们更符合 JAX 的计算约束。由于这些原因,我们认为 scipy.sparse 中的数据结构超出了 JAX 的范围。

另一方面,scipy.sparse.linalg 被证明是一个有趣的领域,jax.scipy.sparse.linalg 包括 bicgstabcggmres 求解器。这些对 JAX 用户社区(轴 6)很有用,但除此之外,它们在其他轴上表现不佳。它们非常适合移入下游库;一个潜在的选择可能是 Lineax,它提供了许多基于 JAX 的线性求解器。

建议:探索将稀疏求解器移入 Lineax,否则将 `scipy.sparse`` 视为 JAX 范围之外。

scipy.spatial#

scipy.spatial 模块主要包含面向对象的空间/距离计算和最近邻搜索接口。它主要超出了 JAX 的范围。

scipy.spatial.transform 子模块提供了用于操作三维空间旋转的工具。它是一个相对复杂的、面向对象的接口,或许可以通过下游项目得到更好的服务。JAX 目前在 jax.scipy.spatial.transform 中包含了 RotationSlerp 的部分实现;这些都是其他基本函数的面向对象包装器,它们引入了非常大的 API 表面,并且用户很少。我们认为它们超出了 JAX 本身的范围,用户通过假设的下游项目得到更好的服务。

scipy.spatial.distance 子模块包含一组有用的距离度量,并且可能很想为这些度量提供 JAX 包装器。也就是说,使用 jit 和 vmap,用户可以轻松地根据需要从头开始定义大多数度量的有效版本,因此将它们添加到 JAX 中并没有特别大的好处。

建议:考虑弃用和删除 RotationSlerp API,并考虑将 scipy.spatial 整体视为未来贡献范围之外。

scipy.special#

scipy.special 模块包括对许多更专业函数的实现。在许多情况下,这些函数完全属于范围:例如,函数如 gammalnbetaincdigamma 以及许多其他函数直接对应于可用的 XLA 原语,并且根据轴 1 和其他因素明确属于范围。

其他函数需要更复杂的实现;上面提到的一个例子是 bessel_jn。尽管它们与轴 1 和 2 不一致,但这些函数往往在轴 6 上非常强大:scipy.special 提供了在各种领域中计算所必需的基本函数,因此即使是具有复杂实现的函数也应该倾向于属于范围,只要实现设计良好且稳健。

有一些现有的函数包装器,我们应该仔细研究;例如

  • jax.scipy.special.lpmn():这通过一个复杂的 fori_loop 生成勒让德多项式,这种方式与 scipy API 不匹配(例如,对于 scipyz 必须是标量,而对于 JAX,z 必须是一维数组)。该函数几乎没有可发现的用途,使其成为轴 1、2、4 和 6 的弱候选者。

  • jax.scipy.special.lpmn_values():这具有与上面 lmpn 相似的缺点。

  • jax.scipy.special.sph_harm():这是基于 lpmn 构建的,同样具有与相应的 scipy 函数不同的 API。

  • jax.scipy.special.bessel_jn():如上文 Axis 4 所述,此函数在实现稳健性和使用率方面存在弱点。我们可以考虑用新的、更稳健的实现来替换它(例如 #17038)。

建议:重构并改进 bessel_jn 的稳健性和测试覆盖率。如果无法修改 lpmnlpmn_valuessph_harm 以更接近 scipy API,则考虑弃用它们。

scipy.stats#

scipy.stats 模块包含广泛的统计函数,包括离散和连续分布、汇总统计量和假设检验。JAX 目前在 jax.scipy.stats 中包装了其中的一部分,主要包括大约 20 种统计分布,以及其他几个函数(moderankdatagaussian_kde)。总的来说,这些函数与 JAX 非常匹配:分布通常可以用高效的 XLA 操作来表示,并且 API 非常简洁且功能性。

我们目前没有针对假设检验函数的任何包装器,可能是因为这些函数对 JAX 的主要用户群来说不太有用。

关于分布,在某些情况下,tensorflow_probability 提供了类似的功能,将来我们可以考虑是否弃用 scipy.stats 中的分布,转而使用该实现。

建议:展望未来,我们应该将统计分布和汇总统计量视为范围内,而将假设检验和相关功能视为范围外。