jax.extend:扩展模块#

@froystig, @sharadmv, @jakevdp, @yashk2810

2023 年 5 月

import jax.extend as jex

一些项目依赖 JAX 的代码库内部,通常是为了使用其核心机制(例如,编写一个在其 IR 上的转换)或扩展它(例如,定义新的原语)。这些依赖的两个挑战是(a)我们的内部结构并非都为外部使用而稳固设计,以及(b)规避 JAX 的公共 API 是不支持的。换句话说,我们的内部结构通常像库一样被使用,但既没有像库那样进行结构化,也没有像库那样进行更新。

本提案考虑引入一个 jax.extend 模块,该模块定义 JAX 的一些内部组件的库视图。我们会将其视为第二层 API,仍然保证基本上没有兼容性策略,但希望使其更容易在发生更改时发现。

jax.extend 的受众包括 JAX 相邻的 Python 库,如 Oryxjax-triton 和许多其他库,以及试验函数转换、自动微分系统、数值编程的编译器前端等项目。

本说明概述了 jax.extend 现在和最终可能的样子。它没有详细阐述,而是建议我们开始迭代开发该模块。

请注意,jax.extendjax.experimental 不同,后者是正在进行的新特性和想法的试验场。通常,jax.experimental 中的工作最终会进入另一个 JAX 模块,或者被完全删除。

无兼容性策略#

为了保持较低的开发开销,jax.extend 不会遵循公共的 API 兼容性策略。它不会承诺弃用窗口,也不会承诺版本之间的向后兼容性。每个版本都可能在没有简单补救措施的情况下(例如,没有重新引入先前行为的标志)破坏现有的调用者。我们将依靠 更新日志来指出此类更改。

需要随着 JAX 版本定期升级其代码的 jax.extend 调用者可能会发现,在版本之间将 JAX 版本固定为一个中间步骤很有用。这是目前依赖 JAX 内部结构的项目的常见习惯。不同之处在于,它现在会借助更新日志公告以及在库设计和命名方面更好的意图。

迭代开发#

没有兼容性策略使得开始实现更容易:在第一天,我们可以从内部包(如 jax._src)和今天的 jax.corejax.interpreters 中移动一些符号。然后我们可以迭代地改进这些东西。

可能的模块概述#

我们可以想象,最终 jax.extend 将包括以下模块

  • core – 原语、Jaxpr IR 等。

  • interpreters – 核心转换(例如,自动微分、批处理)和降级。

  • random – 随机位生成、密钥拆分和折叠、密钥数组。

  • sharding – 围绕分布式数组的额外功能。

我们最初可能还在模块中包含其他符号,例如 jex.api_util,因为我们正在努力删除或替换它们。其他内容将及时决定。例如,jex.lib 可以提供 jaxlib 的入口点(并且会在短期内这样做),但目前尚不清楚我们是否要长期保留它。

以下是关于这些模块各自可能包含的内容的一些初步想法。

jax.extend.core#

这应该至少使调用者能够定义新的 JAX 原语并处理 Jaxpr IR(jax.make_jaxpr(...) 的输出)。支持这一点可能需要提供

  • 访问现有的核心系统原语,例如今天的 jax._src.lax.add_p

  • 访问 IR 类型,例如当前的 jax._src.core.ShapedArray

  • 用于检查和漂亮打印 jaxpr 的函数。

  • 用于显式构建 jaxpr 的函数,而不是通过 jax.make_jaxpr (或不!)暂存 Python 函数。

在初始化时,此模块将包含比定义原语和规则所需更多的符号,包括在设置 “最终样式转换”时使用的各种名称,例如当前的 jax._src.core.TraceTracer 类。我们可以重新审视 jex.core 是否也应该支持最终样式扩展以及初始样式方法,以及它是否可以通过比完全暴露 TraceTracer 更窄的 API 来做到这一点。Oryx 可能有助于指导这些决策。

我们还可以考虑将 make_jaxpr 本身重新定位到 jex.core

jax.extend.interpreters#

此模块将提供一种注册原语的各种转换规则的方法——定义它们在 AD、批处理、降级等下的行为。

它最初将反映 jax._src.interpreters,提供模块 adbatchingpartial_eval(用于将 Python 暂存为 Jaxpr 以及用于 AD 中的线性化)、mlirpxlaxla。前三个可能会被 jex.core 中的单个原语扩展 API 所取代。后三个用于降级,可以简化为一个模块,也许。

今天,要编写转换规则,例如用于 AD 和批处理,调用者可能需要与跟踪器相关的符号,例如 JVPTracerBatchTracer。这在以后可能是可以避免的,并允许我们从 jex 中删除跟踪器类型。

此模块加上 jex.core 应该足以复制今天的自定义原语教程(例如 我们的教程dfm 的教程)。例如,定义一个原语及其在 jax.jit 下的行为,在短期内可以按如下方式进行

from jax.extend import core	         # Previously: from jax import core
from jax.extend.interpreters import mlir        # ... and similarly

mul_add_p = core.Primitive('mul_add')
mul_add_p.def_impl(lambda x, y, z: x * y + z)

@mul_add_p.def_abstract_eval
def mul_add_abstract(x_sa, y_sa, z_sa):
  return core.ShapedArray(x_sa.shape, x_sa.dtype)

def mul_add_mlir(ctx, xc, yc, zc):
  add = mlir.hlo.AddOp
  mul = mlir.hlo.MulOp
  return add(mul(xc, yc), zc).results

mlir.register_lowering(mul_add_p, mul_add_mlir)

import jax
print(mul_add_p.bind(2, 3, 4))            # -> 10
print(jax.jit(mul_add_p.bind)(2, 3, 4))   # -> Array(10, dtype=int32)

jax.extend.random#

此模块可以公开我们用于定义新 RNG 实现的机制,以及用于处理 PRNG 密钥内部结构的函数(请参阅问题 #9263),例如当前的 jax._src.prng.random_wraprandom_unwrap

它还可以公开内置 RNG 实现的基础的密钥哈希函数,例如 jax._src.prng.threefry_2x32

jax.extend.sharding#

此模块可以公开用于分片分布式数组的底层实用程序。

我们现在只有一个想法。XLA 编译器的数组分片格式比 JAX 提供的格式更具表现力。我们可以将其作为 jex.sharding.XlaOpShardingProto 提供,对应于内部的今天的 jax._src.lib.xla_client.OpSharding