jax.extend:一个用于扩展的模块#

@froystig, @sharadmv, @jakevdp, @yashk2810

2023 年 5 月

import jax.extend as jex

有几个项目依赖于 JAX 的代码库内部,通常是为了使用其核心机制(例如,在其 IR 上编写转换)或扩展它(例如,定义新的原语)。这些依赖关系的两个挑战是 (a) 我们的内部组件并非都为外部使用而设计,以及 (b) 规避 JAX 的公共 API 是不支持的。换句话说,我们的内部组件通常像库一样使用,但它们的结构和更新方式都不像库。

本提案考虑引入一个 jax.extend 模块,该模块定义了 JAX 的一些内部组件的库视图。我们会将其视为第二层 API,仍然基本保证没有兼容性策略,但希望在发生更改时更容易发现。

jax.extend 的受众包括 JAX 相关的 Python 库,如 Oryxjax-triton 等等,以及试验函数转换、自动微分系统、数值编程编译器前端等的项目。

本说明概述了 jax.extend 的可能外观,包括现在和最终的样子。它没有详细展开,而是建议我们开始对该模块进行迭代开发

请注意,jax.extendjax.experimental 不同,后者是正在进行中的新特性和想法的试验场。通常,jax.experimental 中的工作最终会进入另一个 JAX 模块或被完全移除。

无兼容性策略#

为了保持较低的开发开销,jax.extend 将不遵循公开的 API 兼容性策略。它不承诺弃用窗口或版本之间的向后兼容性。每个版本都可能破坏现有的调用者,而没有简单的补救措施(例如,没有一个标志可以重新引入先前的行为)。我们将依靠变更日志来指出这些更改。

需要随着 JAX 版本定期升级代码的 jax.extend 调用者可能会发现,在版本之间将 JAX 版本固定下来作为中间步骤很有用。这是当今依赖 JAX 内部项目的常见习惯。不同之处在于,现在它将伴随着变更日志公告以及关于库设计和命名的更好意图。

迭代开发#

没有兼容性策略使得开始实施更容易:在第一天,我们可以从内部包(如 jax._src)和今天的 jax.core 以及 jax.interpreters 中移动一些符号。然后我们可以迭代改进。

可能的模块概述#

我们可以想象,最终 jax.extend 将包括以下模块

  • core – 原语、Jaxpr IR 等。

  • interpreters – 核心转换(例如,自动微分、批处理)和降级。

  • random – 随机位生成、密钥拆分和折叠、密钥数组。

  • sharding – 关于分布式数组的额外功能。

最初,我们可能在该模块中还有其他符号,例如 jex.api_util,因为我们致力于删除或替换它们。其他符号将在未来决定。例如,jex.lib 可以提供 jaxlib 的入口点(并且在短期内会这样做),但我们是否要长期保留它尚不清楚。

下面是一些关于这些模块可能包含内容的初步想法。

jax.extend.core#

这应该使调用者至少能够定义新的 JAX 原语并处理 Jaxpr IR(jax.make_jaxpr(...) 的输出)。支持此功能可能涉及提供

  • 访问现有的核心系统原语,例如今天的 jax._src.lax.add_p

  • 访问 IR 类型,例如当前的 jax._src.core.ShapedArray

  • 用于检查和漂亮打印 jaxpr 的函数。

  • 用于显式构建 jaxpr 的函数,而不是通过 jax.make_jaxpr 暂存 Python 函数(或不暂存!)。

在初始化时,该模块将包含比定义原语和规则所需的更多的符号,包括在设置“最终样式转换”中使用的各种名称,例如当前的 jax._src.core.TraceTracer 类。我们可以重新考虑 jex.core 是否也应该支持最终样式扩展以及初始样式方法,以及它是否可以通过比完全公开 TraceTracer 更窄的 API 来做到这一点。Oryx 可能会帮助指导这些决策。

我们还可以考虑将 make_jaxpr 本身重定位到 jex.core

jax.extend.interpreters#

该模块将提供一种为原语注册各种转换规则的方法——定义它们在 AD、批处理、降级等下的行为。

它最初将反映 jax._src.interpreters,提供模块 adbatchingpartial_eval (用于将 Python 暂存到 Jaxpr,以及 AD 中的线性化)、mlirpxlaxla。前三个可能会被 jex.core 中的单个原语扩展 API 替换。后三个用于降级的模块,可以简化为一个模块,也许吧。

今天,要编写转换规则,例如用于 AD 和批处理,调用者可能需要与跟踪器相关的符号,例如 JVPTracerBatchTracer。这可能在以后可以避免,并允许我们从 jex 中删除跟踪器类型。

该模块加上 jex.core 应该足以复制今天的自定义原语教程(例如,我们的教程dfm 的教程)。例如,定义一个原语及其在 jax.jit 下的行为是可能的,如下所示(在短期内)

from jax.extend import core	         # Previously: from jax import core
from jax.extend.interpreters import mlir        # ... and similarly

mul_add_p = core.Primitive('mul_add')
mul_add_p.def_impl(lambda x, y, z: x * y + z)

@mul_add_p.def_abstract_eval
def mul_add_abstract(x_sa, y_sa, z_sa):
  return core.ShapedArray(x_sa.shape, x_sa.dtype)

def mul_add_mlir(ctx, xc, yc, zc):
  add = mlir.hlo.AddOp
  mul = mlir.hlo.MulOp
  return add(mul(xc, yc), zc).results

mlir.register_lowering(mul_add_p, mul_add_mlir)

import jax
print(mul_add_p.bind(2, 3, 4))            # -> 10
print(jax.jit(mul_add_p.bind)(2, 3, 4))   # -> Array(10, dtype=int32)

jax.extend.random#

该模块可以公开我们用于定义新的 RNG 实现的机制,以及用于处理 PRNG 密钥内部的函数(参见问题 #9263),例如当前的 jax._src.prng.random_wraprandom_unwrap

它还可以公开构成内置 RNG 实现基础的键控哈希函数,例如 jax._src.prng.threefry_2x32

jax.extend.sharding#

该模块可以公开用于分片分布式数组的低级实用程序。

我们现在只有一个想法。XLA 编译器的数组分片格式比 JAX 提供的那些格式更具表现力。我们可以将其作为 jex.sharding.XlaOpShardingProto 提供,对应于内部的 jax._src.lib.xla_client.OpSharding