JEP 18137:JAX NumPy 和 SciPy 包装器的范围#
Jake VanderPlas
2023 年 10 月
到目前为止,jax.numpy
和 jax.scipy
的预期范围相对不明确。本文档为这些软件包提出了明确的范围,以更好地指导和评估未来的贡献,并促使删除一些超出范围的代码。
背景#
从一开始,JAX 的目标就是为在 XLA 中执行代码提供一个类似 NumPy 的 API,并且该项目开发的重要组成部分是构建 jax.numpy
和 jax.scipy
命名空间,作为基于 JAX 的 NumPy 和 SciPy API 实现。一直以来都有一个隐含的理解,即 numpy
和 scipy
的某些部分超出了 JAX 的范围,但此范围并未明确定义。这可能会导致贡献者的困惑和沮丧,因为对于潜在的 jax.numpy
和 jax.scipy
贡献是否会被 JAX 接受,没有明确的答案。
为什么要限制范围?#
为了避免不言自明,我们应该明确指出:事实是,任何包含在 JAX 之类项目中的代码都会给开发人员带来小但非零的持续维护负担。一个项目随着时间的推移取得成功,直接关系到维护人员为项目所有部分的总和继续进行维护的能力:记录功能、回复问题、修复错误等。为了任何软件工具的长期成功和可持续性,维护人员仔细权衡任何特定贡献是否会根据其目标和资源为项目带来净效益至关重要。
评估标准#
本文档提出了一个由六个轴组成的标准,可以根据这些轴来判断任何特定的 numpy
或 scipy
API 是否可以包含在 JAX 中。在所有轴上都很强的 API 是包含在 JAX 包中的绝佳候选者;在六个轴中的任何一个轴上存在严重弱点都是反对包含在 JAX 中的有力论据。
轴 1:XLA 对齐#
我们要考虑的第一个轴是建议的 API 与原生 XLA 操作对齐的程度。例如,jax.numpy.exp()
函数或多或少直接镜像 jax.lax.exp
。 numpy
、scipy.special
、numpy.linalg
、scipy.linalg
和其他许多函数都满足此标准:在考虑将这些函数包含到 JAX 中时,它们通过了 XLA 对齐检查。
另一方面,存在像 numpy.unique()
这样的函数,它们不直接对应于任何 XLA 操作,并且在某些情况下,从根本上与 JAX 当前的计算模型不兼容,该模型需要静态形状的数组(例如,unique
返回一个依赖于值的动态数组形状)。在考虑将这些函数包含到 JAX 中时,它们未通过 XLA 对齐检查。
作为此轴的一部分,我们还考虑了对纯函数语义的需求。例如,numpy.random
构建在隐式更新的基于状态的 RNG 之上,这从根本上与基于 XLA 构建的 JAX 计算模型不兼容。
轴 2:Array API 对齐#
我们要考虑的第二个轴侧重于 Python Array API 标准:从某种意义上说,这是社区驱动的概述,其中说明了哪些数组操作对于各种用户社区中的面向数组的编程至关重要。如果 numpy
或 scipy
中的 API 列在 Array API 标准中,则强烈暗示 JAX 应该包含它。使用上面的示例,Array API 标准包括 numpy.unique()
的几个变体(unique_all
、unique_counts
、unique_inverse
、unique_values
),这表明,尽管该函数与 XLA 的对齐并不精确,但它对于 Python 用户社区来说足够重要,JAX 或许应该实现它。
轴 3:下游实现的存在#
对于与轴 1 或 2 不一致的功能,是否支持良好的下游软件包提供所讨论的功能是包含在 JAX 中的一个重要考虑因素。一个很好的例子是 scipy.optimize
:尽管 JAX 确实包含了一组最少的 scipy.optimize
功能的包装器,但在 JAXopt 软件包中存在更完整的处理方式,该软件包由 JAX 合作者积极维护。在这种情况下,我们应该倾向于将用户和贡献者指向这些专门的软件包,而不是在 JAX 本身中重新实现此类 API。
轴 4:实现的复杂性与稳健性#
对于与 XLA 不一致的功能,一个考虑因素是建议的实现方案的复杂程度。这在某种程度上与轴 1 一致,但仍然很重要。许多函数被贡献给了 JAX,它们的实现方案相对复杂,难以验证并引入了过大的维护负担;一个例子是 jax.scipy.special.bessel_jn()
:在编写此 JEP 时,其当前实现是一个不直接的迭代近似,在某些领域存在收敛问题,并且建议的修复引入了进一步的复杂性。如果在接受贡献时,我们更加仔细地权衡了实现的复杂性和稳健性,我们可能会选择不接受对该软件包的贡献。
轴 5:函数式 API 与面向对象 API#
JAX 最适合函数式 API,而不是面向对象 API。面向对象 API 通常会隐藏不纯的语义,使得它们通常难以很好地实现。NumPy 和 SciPy 通常坚持使用函数式 API,但有时会提供面向对象的便捷包装器。
这方面的例子是 numpy.polynomial.Polynomial
,它包装了诸如 numpy.polyadd()
、numpy.polydiv()
等的底层操作。一般来说,当同时提供函数式和面向对象的 API 时,JAX 应该避免为面向对象的 API 提供包装器,而是为函数式 API 提供包装器。
在仅存在面向对象的 API 的情况下,除非在其他方面有强有力的理由,否则 JAX 应避免提供包装器。
轴 6:对 JAX 用户和利益相关者的总体“重要性”#
在 JAX 中包含 NumPy/SciPy API 的决定也应考虑到该算法对一般用户社区的重要性。 诚然,很难量化谁是“利益相关者”以及如何衡量这种重要性; 但我们纳入这一点是为了明确,关于在 JAX 的 NumPy 和 SciPy 包装器中包含什么内容的任何决定都将涉及一定程度的酌处权,而这无法轻易量化。
对于现有的 API,在 github 中搜索使用情况可能有助于确定其重要性或缺乏重要性; 例如,我们可以回到上面讨论的 jax.scipy.special.bessel_jn()
:搜索结果表明,此函数在 github 上只有 少量使用,这可能部分与之前提到的准确性问题有关。
评估:范围是什么?#
在本节中,我们将尝试根据上述规则评估 NumPy 和 SciPy API,包括当前 JAX API 中的一些示例。 这不会是所有现有函数和类的全面列表,而是按子模块和主题进行更一般的讨论,并提供相关示例。
NumPy API#
✅ numpy
命名空间#
我们认为主 numpy
命名空间中的函数基本上都在 JAX 的范围内,因为它与 XLA(轴 1)和 Python 数组 API(轴 2)的通用对齐,以及它对 JAX 用户社区的总体重要性(轴 6)。 一些函数可能处于边缘(例如 numpy.intersect1d()
、np.setdiff1d()
、np.union1d()
这样的函数可以说不符合规则的某些部分),但为了简单起见,我们声明主 numpy 命名空间中的所有数组函数都在 JAX 的范围内。
✅ numpy.linalg
& numpy.fft
#
numpy.linalg
和 numpy.fft
子模块包含许多与 XLA 提供的功能广泛一致的函数。 其他函数具有复杂的特定于设备的底层实现,但代表了利益相关者的重要性(轴 6)大于复杂性的情况。 因此,我们认为这两个子模块都在 JAX 的范围内。
❌ numpy.random
#
numpy.random
超出了 JAX 的范围,因为基于状态的 RNG 从根本上与 JAX 的计算模型不兼容。 我们转而关注 jax.random
,它使用基于计数器的 PRNG 提供类似的功能。
❌ numpy.ma
& numpy.polynomial
#
numpy.ma
和 numpy.polynomial
子模块主要关注为可以通过其他函数方式表达的计算提供面向对象的接口(轴 5); 因此,我们认为它们超出了 JAX 的范围。
❌ numpy.testing
#
NumPy 的测试功能仅对主机端计算有意义,因此我们不会在 JAX 中包含任何它的包装器。 尽管如此,JAX 数组与 numpy.testing
兼容,并且 JAX 在整个 JAX 测试套件中频繁使用它。
SciPy API#
SciPy 在顶层命名空间中没有函数,但包含许多子模块。 我们将在下面考虑每个子模块,并排除已弃用的模块。
❌ scipy.cluster
#
scipy.cluster
模块包括用于层次聚类、k 均值和相关算法的工具。 这些在多个轴上表现较弱,最好由下游软件包提供。 JAX 中已存在一个函数 (jax.scipy.cluster.vq.vq()
),但在 github 上 没有明显的使用:这表明聚类对 JAX 用户来说并不重要。
建议:弃用并删除 jax.scipy.cluster.vq()
。
❌ scipy.constants
#
scipy.constants
模块包括数学和物理常数。 这些常量可以直接与 JAX 一起使用,因此没有理由在 JAX 中重新实现它。
❌ scipy.datasets
#
scipy.datasets
模块包括用于获取和加载数据集的工具。 这些获取的数据集可以直接与 JAX 一起使用,因此没有理由在 JAX 中重新实现它。
✅ scipy.fft
#
scipy.fft
模块包含与 XLA 提供的功能广泛一致的函数,并且在其他轴上也表现良好。 因此,我们认为它们在 JAX 的范围内。
❌ scipy.integrate
#
scipy.integrate
模块包含用于数值积分的函数。 其中更复杂的函数(quad
、dblquad
、ode
)由于倾向于基于动态评估次数的循环算法,因此在轴 1 和 4 上都超出了 JAX 的范围。jax.experimental.ode.odeint()
是相关的,但相当有限,并且未处于任何积极的开发中。
JAX 目前确实包含 jax.scipy.integrate.trapezoid()
,但这仅仅是因为最近 numpy.trapz()
已被弃用而改用它。 对于任何特定输入,它的实现都可以用一行 jax.numpy
表达式替换,因此提供它并不是一个特别有用的 API。
根据轴 1、2、4 和 6,scipy.integrate
应被视为超出 JAX 的范围。
建议:删除在 JAX 0.4.14 中添加的 jax.scipy.integrate.trapezoid()
。
❌ scipy.interpolate
#
scipy.interpolate
模块提供了一维或多维插值的底层和面向对象的例程。这些 API 在上述多个方面表现不佳:它们是基于类的而不是底层的,并且除了最简单的方法外,没有其他方法可以有效地用 XLA 操作表示。
JAX 目前确实有 scipy.interpolate.RegularGridInterpolator
的包装器。如果我们今天考虑这项贡献,我们可能会根据上述标准拒绝它。但此代码一直相当稳定,因此继续维护它并没有太多缺点。
展望未来,我们应将 scipy.interpolate
的其他成员视为 JAX 的范围之外。
❌ scipy.io
#
scipy.io
子模块与文件输入/输出有关。没有理由在 JAX 中重新实现它。
✅ scipy.linalg
#
scipy.linalg
子模块包含的功能与 XLA 提供的功能大致一致,并且快速线性代数对 JAX 用户社区非常重要。因此,我们认为它在 JAX 的范围内。
❌ scipy.ndimage
#
scipy.ndimage
子模块包含一组用于处理图像数据的工具。其中许多工具与 scipy.signal
中的工具重叠(例如,卷积和滤波)。JAX 目前在 jax.scipy.ndimage.map_coordinates()
中提供了一个 scipy.ndimage
API。此外,JAX 在 jax.image
模块中提供了一些与图像相关的工具。DeepMind 生态系统包括 dm-pix,这是一组更全面的 JAX 图像操作工具。考虑到所有这些因素,我认为应该将 scipy.ndimage
视为 JAX 核心的范围之外;我们可以将感兴趣的用户和贡献者指向 dm-pix。我们可以考虑将 map_coordinates
移动到 dm-pix
或其他合适的包中。
❌ scipy.odr
#
scipy.odr
模块提供了一个围绕 ODRPACK
的面向对象的包装器,用于执行正交距离回归。目前还不清楚这是否可以使用现有的 JAX 原语干净地表达,因此我们认为它在 JAX 本身的范围之外。
❌ scipy.optimize
#
scipy.optimize
模块为优化提供了高级和低级接口。此类功能对许多 JAX 用户非常重要,并且 JAX 很早就创建了 jax.scipy.optimize
包装器。但是,这些例程的开发人员很快意识到 scipy.optimize
API 太过受限,不同的团队开始开发 JAXopt 包和 Optimistix 包,每个包都包含一组更加全面且经过更好测试的 JAX 优化例程。
由于这些有良好支持的外部包,我们现在认为 scipy.optimize
超出 JAX 的范围。
建议:弃用 jax.scipy.optimize
和/或使其成为围绕可选 JAXopt 或 Optimistix 依赖项的轻量级包装器。
🟡 scipy.signal
#
scipy.signal
模块是混合的:某些函数完全在 JAX 的范围内(例如,correlate
和 convolve
,它们是 lax.conv_general_dilated
更用户友好的包装器),而许多其他函数完全超出范围(没有可行的 XLA 降级路径的特定领域工具)。对 jax.scipy.signal
的潜在贡献必须根据具体情况进行权衡。
🟡 scipy.sparse
#
scipy.sparse
子模块主要包含用于以各种格式存储和操作稀疏矩阵和数组的数据结构。此外,scipy.sparse.linalg
包含许多无矩阵求解器,适用于稀疏矩阵、密集矩阵和线性算子。
scipy.sparse
数组和矩阵数据结构超出 JAX 的范围,因为它们与 JAX 的计算模型不一致(例如,许多操作依赖于动态大小的缓冲区)。JAX 开发了 jax.experimental.sparse
模块作为一组与 JAX 计算约束更加一致的替代数据结构。由于这些原因,我们认为 scipy.sparse
中的数据结构超出 JAX 的范围。
另一方面,scipy.sparse.linalg
证明是一个有趣的领域,并且 jax.scipy.sparse.linalg
包括 bicgstab
、cg
和 gmres
求解器。这些对 JAX 用户社区(第 6 轴)很有用,但除此之外,在其他方面表现不佳。它们非常适合移动到下游库中;一个潜在的选择可能是 Lineax,它具有许多基于 JAX 构建的线性求解器。
建议:探索将稀疏求解器移动到 Lineax 中,否则将 scipy.sparse
视为 JAX 的范围之外。
❌ scipy.spatial
#
scipy.spatial
模块主要包含空间/距离计算和最近邻搜索的面向对象的接口。它大多不在 JAX 的范围内。
scipy.spatial.transform
子模块提供了用于操作三维空间旋转的工具。它是一个相对复杂的面向对象的接口,也许可以由下游项目更好地服务。JAX 目前在 jax.scipy.spatial.transform
中包含了 Rotation
和 Slerp
的部分实现;这些是面向对象的包装器,它们是基本函数,引入了非常大的 API 表面并且用户很少。我们判断它们不在 JAX 本身的范围内,用户最好使用假设的下游项目。
scipy.spatial.distance
子模块包含有用的距离度量集合,并且可能很想为这些度量提供 JAX 包装器。也就是说,通过 jit 和 vmap,用户可以轻松地从头开始定义大多数这些度量的有效版本,如果需要的话,因此将它们添加到 JAX 中并没有特别的好处。
建议:考虑弃用并移除 Rotation
和 Slerp
API,并考虑将 scipy.spatial
整个模块排除在未来的贡献范围之外。
✅ scipy.special
#
scipy.special
模块包含许多更专业函数的实现。在许多情况下,这些函数完全在范围内:例如,像 gammaln
、betainc
、digamma
等许多函数直接对应于可用的 XLA 原语,并且根据轴 1 和其他轴的定义,显然在范围内。
其他函数需要更复杂的实现;上面提到的一个例子是 bessel_jn
。尽管与轴 1 和 2 不一致,但这些函数往往在轴 6 上表现非常出色:scipy.special
提供了各种领域计算所需的基本函数,因此即使是实现复杂的函数也应倾向于纳入范围内,只要实现设计良好且健壮即可。
有一些现有的函数包装器,我们应该仔细研究一下;例如
jax.scipy.special.lpmn()
:这通过一个复杂的 fori_loop 生成勒让德多项式,其方式与 scipy API 不匹配(例如,对于scipy
,z
必须是标量,而对于 JAX,z
必须是一维数组)。该函数的使用场景很少,使其在轴 1、2、4 和 6 上成为一个弱候选项。jax.scipy.special.lpmn_values()
:这与上面的lmpn
有相似的弱点。jax.scipy.special.sph_harm()
:这是建立在 lpmn 之上的,并且类似地,其 API 与相应的scipy
函数不同。jax.scipy.special.bessel_jn()
:如上文在轴 4 下讨论的,这在实现的稳健性和使用率方面存在弱点。我们可能会考虑用新的、更健壮的实现来替换它(例如 #17038)。
建议:重构并提高 bessel_jn
的稳健性和测试覆盖率。如果 lpmn
、lpmn_values
和 sph_harm
无法修改为更紧密地匹配 scipy
API,则考虑弃用它们。
✅ scipy.stats
#
scipy.stats
模块包含各种统计函数,包括离散和连续分布、汇总统计和假设检验。JAX 目前在 jax.scipy.stats
中包装了其中的一些函数,主要包括 20 个左右的统计分布,以及其他一些函数(mode
、rankdata
、gaussian_kde
)。总的来说,这些与 JAX 非常一致:分布通常可以用高效的 XLA 操作来表示,并且 API 是清晰且功能性的。
我们目前没有任何用于假设检验函数的包装器,可能是因为这些函数对 JAX 的主要用户群不太有用。
关于分布,在某些情况下,tensorflow_probability
提供了类似的功能,未来我们可能会考虑是否弃用 scipy.stats 分布,转而使用该实现。
建议:展望未来,我们应该将统计分布和汇总统计视为范围内的,并将假设检验和相关功能通常视为范围外的。