jax.extend:扩展模块#

@froystig@sharadmv@jakevdp@yashk2810

2023年5月

import jax.extend as jex

一些项目依赖于 JAX 的代码库内部,通常是为了使用其核心机制(例如,在其 IR 上编写转换)或扩展它(例如,定义新的原语)。这些依赖项面临的两个挑战是 (a) 我们的内部组件并非都稳固地设计用于外部使用,以及 (b) 绕过 JAX 的公共 API 是不受支持的。换句话说,我们的内部组件通常像库一样使用,但其结构和更新方式却并非如此。

本提案考虑 **引入一个 jax.extend 模块,该模块定义了 JAX 的一些内部组件的库视图**。我们将将其视为二级 API,仍然基本上保证无兼容性策略,但希望在发生更改时更容易发现这些更改。

jax.extend 的目标受众包括 JAX 相关的 Python 库,如 Oryxjax-triton 等,以及尝试函数转换、自动微分系统、数值编程编译器前端等的项目。

本说明概述了 jax.extend 的外观,现在和未来。它没有详细说明,而是建议我们开始迭代开发该模块。

请注意,jax.extendjax.experimental不同,后者是用于存放正在开发的新功能和想法的暂存地。通常,jax.experimental中的工作最终会进入另一个JAX模块或完全移除。

无兼容性策略#

为了降低开发开销,jax.extend不会遵循公开的API兼容性策略。它不会承诺任何弃用窗口或版本之间的向后兼容性。每个版本都可能破坏现有的调用者,而没有简单的补救措施(例如,没有重新引入先前行为的标志)。我们将依靠变更日志来指明此类更改。

需要定期随着JAX版本升级其代码的jax.extend的调用者可能会发现,在版本之间将JAX版本固定作为中间步骤很有用。这在今天依赖JAX内部机制的项目中是一个常见的习惯。不同之处在于,它现在将借助变更日志公告和关于库设计和命名的更好意图。

迭代开发#

没有兼容性策略使得更容易开始实施:在第一天,我们可以从内部包(如jax._src)以及今天的jax.corejax.interpreters中迁移少量符号。然后我们可以迭代改进。

可能的模块概述#

我们可以想象,最终jax.extend将包含以下模块

  • core – 原语、Jaxpr IR等。

  • interpreters – 核心转换(例如自动微分、批处理)和降低。

  • random – 随机比特生成、密钥拆分和折叠、密钥数组。

  • sharding – 分布式数组周围的额外功能。

我们也可能一开始在模块中包含其他符号,例如jex.api_util,因为我们正在努力移除或替换它们。其他的将在以后决定。例如,jex.lib可以提供jaxlib的入口点(并且会在短期内这样做),但目前尚不清楚我们是否希望长期保留它。

以下是一些关于每个模块可能包含内容的初步想法。

jax.extend.core#

这应该至少使调用者能够定义新的JAX原语并处理Jaxpr IR(jax.make_jaxpr(...)的输出)。支持这一点可能涉及提供

  • 对现有核心系统原语的访问,例如今天的jax._src.lax.add_p

  • 对IR类型的访问,例如当前的jax._src.core.ShapedArray

  • 用于检查和美化jaxprs的函数。

  • 用于显式构建jaxprs的函数,而不是通过jax.make_jaxpr(或不通过)对Python函数进行分阶段处理。

在初始化时,此模块将包含比定义原语和规则所需的更多符号,包括在设置“最终样式转换”时使用的各种名称,例如当前的jax._src.core.TraceTracer类。我们可以重新考虑jex.core是否也应该支持最终样式扩展以及初始样式方法,以及它是否可以通过比完全公开TraceTracer更窄的API来做到这一点。Oryx可能会帮助指导这些决策。

我们还可以考虑将make_jaxpr本身重新定位到jex.core

jax.extend.interpreters#

此模块将提供一种方法来注册原语的各种转换规则——定义它们在AD、批处理、降低等下的行为。

它最初会反映jax._src.interpreters,提供adbatchingpartial_eval(用于将Python分阶段处理为Jaxpr,以及用于AD中的线性化)、mlirpxlaxla模块。前三个可能可以用jex.core中的单个原语扩展API替换。后三个用于降低,可能可以简化为一个模块。

今天,要编写转换规则,例如AD和批处理,调用者可能需要与跟踪器相关的符号,例如JVPTracerBatchTracer。这以后可能会避免,并允许我们从jex中删除跟踪器类型。

此模块加上jex.core应该足以复制今天的自定义原语教程(例如我们的dfm的)。例如,定义一个原语及其在jax.jit下的行为将是可能的(在短期内)

from jax.extend import core	         # Previously: from jax import core
from jax.extend.interpreters import mlir        # ... and similarly

mul_add_p = core.Primitive('mul_add')
mul_add_p.def_impl(lambda x, y, z: x * y + z)

@mul_add_p.def_abstract_eval
def mul_add_abstract(x_sa, y_sa, z_sa):
  return core.ShapedArray(x_sa.shape, x_sa.dtype)

def mul_add_mlir(ctx, xc, yc, zc):
  add = mlir.hlo.AddOp
  mul = mlir.hlo.MulOp
  return add(mul(xc, yc), zc).results

mlir.register_lowering(mul_add_p, mul_add_mlir)

import jax
print(mul_add_p.bind(2, 3, 4))            # -> 10
print(jax.jit(mul_add_p.bind)(2, 3, 4))   # -> Array(10, dtype=int32)

jax.extend.random#

此模块可以公开我们用于定义新的RNG实现的机制,以及用于处理PRNG密钥内部机制的函数(请参阅问题#9263),例如当前的jax._src.prng.random_wraprandom_unwrap

它还可以公开构成内置RNG实现基础的密钥散列函数,例如jax._src.prng.threefry_2x32

jax.extend.sharding#

此模块可以公开用于对分布式数组进行分片的低级实用程序。

我们目前只有一个想法。XLA编译器的数组分片格式比JAX提供的那些更具表现力。我们可以将其提供为jex.sharding.XlaOpShardingProto,对应于今天内部的jax._src.lib.xla_client.OpSharding