JAX 类型标注路线图#

  • 作者:jakevdp

  • 日期:2022 年 8 月

背景#

Python 3.0 引入了可选的函数标注 (PEP 3107),后来在 Python 3.5 左右的版本中被正式用于静态类型检查 (PEP 484)。在一定程度上,类型标注和静态类型检查已成为许多 Python 开发工作流程中不可或缺的一部分,为此,我们在 JAX API 的许多地方添加了标注。JAX 中类型标注的现状有点像拼凑而成,并且添加更多标注的努力因更基本的设计问题而受阻。本文档试图总结这些问题并为 JAX 中类型标注的目标和非目标生成路线图。

为什么我们需要这样的路线图? 更好的/更全面的类型注释是用户(内部和外部)的常见要求。 此外,我们经常收到来自外部用户的拉取请求(例如,PR #9917PR #10322),寻求改进 JAX 的类型注释:JAX 团队成员审查代码时并不总是清楚此类贡献是否有益,特别是当它们引入复杂的协议来解决 JAX 使用 Python 进行全面注释固有挑战时。 本文档详细介绍了 JAX 在包中进行类型注释的目标和建议。

为什么进行类型注释?#

Python 项目可能希望注释其代码库的原因有很多;我们将在本文件中将其总结为第一级、第二级和第三级。

第一级:注释作为文档#

在最初引入 PEP 3107 时,类型注释的部分动机是能够将它们用作函数参数类型和返回类型的简洁、内联文档。 JAX 长期以来一直以这种方式使用注释;一个例子是创建类型名称的常见模式,这些类型名称与 Any 相关联。 一个例子可以在 lax/slicing.py 中找到 [源代码]

Array = Any
Shape = core.Shape

def slice(operand: Array, start_indices: Sequence[int],
          limit_indices: Sequence[int],
          strides: Optional[Sequence[int]] = None) -> Array:
  ...

为了进行静态类型检查,这种对 Array = Any 进行数组类型注释的使用不会对参数值施加任何限制(Any 等同于没有注释),但它确实为开发人员提供了一种有用的代码内文档形式。

为了生成文档,别名名称会丢失(jax.lax.sliceHTML 文档 报告操作数类型为 Any),因此文档优势不会超出源代码(尽管我们可以启用一些 sphinx-autodoc 选项来改进这一点:参见 autodoc_type_aliases)。

这种级别类型注释的一个好处是,用 Any 注释值永远不会出错,因此它将以文档的形式为开发人员和用户提供具体的好处,而不会增加满足任何特定静态类型检查器更严格需求的复杂性。

第二级:用于智能自动完成的注释#

许多现代 IDE 利用类型注释作为 智能代码完成 系统的输入。 一个例子是 VSCode 的 Pylance 扩展,它使用微软的 pyright 静态类型检查器作为 VSCode 的 IntelliSense 完成信息的来源。

这种类型检查的使用要求比上面使用的简单别名更进一步;例如,知道 slice 函数返回名为 ArrayAny 的别名不会向代码完成引擎添加任何有用的信息。 但是,如果我们用 DeviceArray 返回类型注释该函数,自动完成将知道如何填充结果的命名空间,因此能够在开发过程中建议更相关的自动完成。

JAX 已经开始在少数地方添加这种级别的类型注释;一个例子是 jnp.ndarray 返回类型在 jax.random 包中 [源代码]

def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray:
  ...

在这种情况下,jnp.ndarray 是一个抽象基类,它前向声明 JAX 数组的属性和方法(参见源代码),因此 VSCode 中的 Pylance 可以提供该函数结果的完整自动完成集。 以下是一个显示结果的屏幕截图

VSCode Intellisense Screenshot

自动完成字段中列出了抽象 ndarray 类声明的所有方法和属性。 我们将在下面进一步讨论为什么必须创建这个抽象类,而不是直接用 DeviceArray 注释。

第三级:用于静态类型检查的注释#

如今,静态类型检查通常是人们在考虑 Python 代码中类型注释目的时首先想到的。 虽然 Python 不会对类型进行任何运行时检查,但存在几种成熟的静态类型检查工具,可以将此作为 CI 测试套件的一部分来完成。 对 JAX 来说最重要的工具如下

  • python/mypy 或多或少是开源 Python 世界中的标准。 JAX 目前在 Github Actions CI 检查中对源文件的一个子集运行 mypy。

  • google/pytype 是 Google 的静态类型检查器,依赖于 JAX 的 Google 内部项目经常使用它。

  • microsoft/pyright 非常重要,因为它是在 VSCode 中为 Pylance 完成使用的静态类型检查器。

完整的静态类型检查是所有类型注释应用中最严格的,因为它会在你的类型注释不完全正确时出现错误。 一方面,这很好,因为你的静态类型分析可能会发现错误的类型注释(例如,DeviceArray 方法在 jnp.ndarray 抽象类中缺失的情况)。

另一方面,这种严格性可能会使类型检查过程在经常依赖鸭子类型而不是严格类型安全 API 的包中变得非常脆弱。 你目前会发现像 #type: ignore(对于 mypy)或 #pytype: disable(对于 pytype)这样的代码注释散布在整个 JAX 代码库中的数百个地方。 这些通常代表出现类型问题的情况;它们可能是 JAX 类型注释中的不准确,也可能是静态类型检查器正确跟踪代码中控制流的能力不准确。 有时,它们是由于 pytype 或 mypy 行为中的真实且细微的错误造成的。 在极少数情况下,它们可能是由于 JAX 使用了难以用 Python 的静态类型注释语法表达的 Python 模式,甚至无法表达造成的。

JAX 的类型注释挑战#

JAX 目前拥有混合了不同风格的类型注释,并且针对上面讨论的所有三个级别的类型注释。 部分原因是 JAX 的源代码对 Python 的类型注释系统提出了许多独特的挑战。 我们将在下面概述它们。

挑战 1:pytype、mypy 和开发人员摩擦#

JAX 目前面临的一个挑战是,包开发必须满足两个不同的静态类型检查系统 pytype(由内部 CI 和 Google 内部项目使用)和 mypy(由外部 CI 和外部依赖项使用)的约束。 尽管这两个类型检查器在行为上有很大的重叠,但每个检查器都呈现出自己独特的边缘情况,这一点从整个 JAX 代码库中众多 #type: ignore#pytype: disable 语句中可以看出来。

这在开发中造成了摩擦:内部贡献者可能会进行迭代,直到测试通过,但结果却发现,在导出时,他们经过 pytype 批准的代码违反了 mypy。 对于外部贡献者来说,情况往往相反:最近的一个例子是 #9596,它在通过 Google 内部 pytype 检查后不得不回滚。 每当我们将类型注释从第一级(Any 无处不在)移动到第二级或第三级(更严格的注释)时,就会引入更多这种令人沮丧的开发人员体验的可能性。

挑战 2:数组鸭子类型#

为 JAX 代码添加注释的一个特殊挑战是它对鸭子类型的广泛使用。 一般来说,标记为 Array 的函数的输入可以是许多不同的类型之一:JAX DeviceArray、NumPy np.ndarray、NumPy 标量、Python 标量、Python 序列、具有 __array__ 属性的对象、具有 __jax_array__ 属性的对象,或任何形式的 jax.Tracer。 因此,像 def func(x: DeviceArray) 这样的简单注释将不足以满足需求,并且会导致许多有效使用出现误报。 这意味着 JAX 函数的类型注释不会简短或简单,但我们必须有效地开发出一套 JAX 特定的类型扩展,类似于 numpy.typing 中的类型扩展。

挑战 3:转换和装饰器#

JAX 的 Python API 很大程度上依赖于函数变换(jit()vmap()grad() 等),这种类型的 API 对静态类型分析提出了特殊的挑战。装饰器的灵活注释一直是 mypy 包中的一个 长期存在的问题,直到最近才通过引入 ParamSpec 解决,该问题在 PEP 612 中进行了讨论,并在 Python 3.10 中添加。由于 JAX 遵循 NEP 29,因此它不能依赖 Python 3.10 的功能,直到 2024 年年中之后。在此期间,协议可以作为此问题的部分解决方案(JAX 在 #9950 中为 jit 和其他方法添加了此功能),并且可以通过 typing_extensions 包使用 ParamSpec(#9999 中有一个原型),尽管这目前暴露了 mypy 中的基本错误(请参阅 python/mypy#12593)。总而言之:尚不清楚 JAX 函数变换的 API 是否可以在当前 Python 类型注释工具的约束范围内得到适当的注释。

挑战 4:数组注释缺乏粒度#

另一个挑战是所有面向数组的 Python API 都面临的共同问题,并且多年来一直是 JAX 讨论的一部分(请参阅 #943)。类型注释与 Python 类或对象的类型有关,而在基于数组的语言中,类的属性往往更重要。在 NumPy、JAX 和类似包的情况下,我们通常希望注释特定的数组形状和数据类型。

例如,jnp.linspace 函数的参数必须是标量值,但在 JAX 中,标量由零维数组表示。因此,为了防止注释出现误报,我们必须允许这些参数为任意数组。另一个例子是 jax.random.choice 的第二个参数,当 shape=() 时,它必须具有 dtype=int。Python 计划通过可变参数类型泛型(请参阅 PEP 646,计划在 Python 3.11 中发布)启用具有这种粒度的类型注释,但与 ParamSpec 一样,对该功能的支持需要一段时间才能稳定下来。

在此期间,有一些第三方项目可能会有所帮助,特别是 google/jaxtyping,但它使用非标准注释,可能不适合注释 JAX 库本身。总而言之,数组类型粒度挑战不如其他挑战那么严重,因为其主要影响是数组类注释将比原本可能的情况更不具体。

挑战 5:从 NumPy 继承的精度较低的 API#

JAX 的很大一部分面向用户的 API 是从 NumPy 中继承的,位于 jax.numpy 子模块中。NumPy 的 API 是在静态类型检查成为 Python 语言的一部分之前很早开发的,并且遵循 Python 的历史建议,使用 鸭子类型/EAFP 编码风格,其中不鼓励在运行时进行严格的类型检查。作为此的具体示例,请考虑 numpy.tile() 函数,其定义如下

def tile(A, reps):
  try:
    tup = tuple(reps)
  except TypeError:
    tup = (reps,)
  d = len(tup)
  ...

此处,意图reps 将包含一个 int 或一个 int 值序列,但实现允许 tup 为任何可迭代对象。在为这种鸭子类型代码添加注释时,我们可以选择两种方式之一

  1. 我们可以选择注释函数 API 的意图,此处可能是类似 reps: Union[int, Sequence[int]] 的内容。

  2. 相反,我们可以选择注释函数的实现,此处可能看起来像 reps: Union[ConvertibleToInt, Iterable[ConvertibleToInt]],其中 ConvertibleToInt 是一个特殊协议,涵盖了函数将输入转换为整数的确切机制(即通过 __int__,通过 __index__,通过 __array__ 等)。还要注意,严格来说,Iterable 在这里不够,因为 Python 中有一些对象是鸭子类型,但无法通过 Iterable 进行静态类型检查(即,通过 __getitem__ 而不是 __iter__ 进行迭代的对象)。

方法 #1 的优点,即注释意图,是注释在传达 API 合同方面对用户更有用;而对于开发人员来说,这种灵活性在必要时为重构留下了空间。缺点(特别是对于像 JAX 的逐步类型化的 API)是,用户代码很可能存在,这些代码可以正常运行,但类型检查器会将其标记为不正确。逐步类型化现有鸭子类型 API 意味着当前注释隐式地为 Any,因此将其更改为更严格的类型可能会对用户来说是一种重大更改。

总的来说,注释意图更适合 Level 1 类型检查,而注释实现更适合 Level 3,Level 2 则更像是混合包(在 IDE 中进行注释时,意图和实现都很重要)。

JAX 类型注释路线图#

牢记这种框架(Level 1/2/3)和 JAX 特定的挑战,我们可以开始制定在整个 JAX 项目中实施一致类型注释的路线图。

指导原则#

对于 JAX 类型注释,我们将遵循以下原则

类型注释的用途#

我们希望尽可能支持完整的Level 1、2 和 3 类型注释。特别是,这意味着我们应该在公共 API 函数的输入和输出上都具有限制性的类型注释。

注释意图#

JAX 类型注释通常应表示 API 的**意图**,而不是实现,这样注释就成为传达 API 合同的有用方法。这意味着有时在运行时有效的输入可能无法被静态类型检查器识别为有效(一个例子可能是传递给注释为 Shape = Sequence[int] 的形状的任意迭代器)。

输入应具有允许类型#

JAX 函数和方法的输入应尽可能具有允许类型:例如,虽然形状通常是元组,但接受形状的函数应接受任意序列。类似地,接受 dtype 的函数不需要要求 np.dtype 类的实例,而是任何可转换为 dtype 的对象。这可能包括字符串、内置标量类型或标量对象构造函数,例如 np.float64jnp.float64。为了使这在整个包中尽可能统一,我们将添加一个 jax.typing 模块,其中包含常见的类型规范,从广泛的类别开始,例如

  • ArrayLike 将是任何可以隐式转换为数组的组合:例如,jax 数组、numpy 数组、JAX 跟踪器以及 python 或 numpy 标量

  • DTypeLike 将是任何可以隐式转换为 dtype 的组合:例如,numpy dtype、numpy dtype 对象、jax dtype 对象、字符串以及内置类型。

  • ShapeLike 将是任何可以转换为形状的组合:例如,整数或类似整数的对象的序列。

  • 等等。

请注意,这些通常比 numpy.typing 中使用的等效协议更简单。例如,在 DTypeLike 的情况下,JAX 不支持结构化 dtype,因此 JAX 可以使用更简单的实现。类似地,在 ArrayLike 中,JAX 通常不支持将列表或元组输入作为数组,因此类型定义将比 NumPy 的类似物更简单。

输出应具有严格类型#

相反,函数和方法的输出应尽可能具有严格类型:例如,对于返回数组的 JAX 函数,输出应使用类似 jnp.ndarray 的内容进行注释,而不是 ArrayLike。返回 dtype 的函数应始终注释为 np.dtype,返回形状的函数应始终为 Tuple[int] 或严格类型化的 NamedShape 等效项。为此,我们将在 jax.typing 中实现上述允许类型的几个严格类型化的类似物,即

  • ArrayNDArray(见下文)用于类型标注目的,实际上等效于 Union[Tracer, jnp.ndarray],应该用于标注数组输出。

  • DTypenp.dtype 的别名,可能还可以表示 JAX 内部使用的关键类型和其他泛化。

  • Shape 本质上是 Tuple[int, ...],可能还具有一些额外的灵活性以适应动态形状。

  • NamedShapeShape 的扩展,它允许使用 JAX 内部使用的命名形状。

  • 等等。

我们还将探讨是否可以放弃当前 jax.numpy.ndarray 的实现,转而将 ndarray 设为 Array 或类似的别名。

追求简洁#

除了收集在 jax.typing 中的常见类型协议之外,我们应该尽可能追求简洁。我们应该避免为传递给 API 函数的参数构建过于复杂的协议,而是在 API 的完整类型规范无法简洁地指定的情况下,使用简单的联合,例如 Union[simple_type, Any]。这是一种折衷方案,它实现了级别 1 和 2 标注的目标,同时放弃了级别 3,以避免不必要的复杂性。

避免不稳定的类型机制#

为了避免增加不必要的开发摩擦(由于内部/外部 CI 的差异),我们希望在使用的类型标注结构方面保持保守:特别是对于最近引入的机制,例如 ParamSpec (PEP 612) 和可变参数类型泛型 (PEP 646),我们希望等待 mypy 和其他工具的支持成熟稳定,然后再依赖它们。

这方面的一个影响是,目前,当函数被 JAX 变换(如 jitvmapgrad 等)装饰时,JAX 将有效地 **剥离装饰函数的所有标注**。虽然这令人遗憾,但在撰写本文时,mypy 与 ParamSpec 提供的潜在解决方案存在大量不兼容问题(参见 ParamSpec mypy 错误跟踪器),因此我们认为它目前还没有准备好完全应用于 JAX。我们将在未来重新审视这个问题,届时此类功能的支持将更加稳定。

同样,目前,我们将避免添加 jaxtyping 项目提供的更复杂和粒度更细的数组类型标注。这是一个我们可以在未来某个时间重新审视的决定。

Array 类型设计注意事项#

如上所述,在 JAX 中对数组进行类型标注是一个独特的挑战,因为 JAX 广泛使用鸭子类型,即在 jax 变换中用 Tracer 对象代替实际数组进行传递和返回。这变得越来越令人困惑,因为用于类型标注的对象经常与用于运行时实例检查的对象重叠,并且可能与所讨论对象的实际类型层次结构相对应,也可能不对应。对于 JAX,我们需要提供鸭子类型对象,用于两个上下文:**静态类型标注** 和 **运行时实例检查**。

以下讨论将假设 jax.Array 是设备上数组的运行时类型,这还没有实现,但一旦 #12016 中的工作完成,就会实现。

静态类型标注#

我们需要提供一个可以用于鸭子类型标注的对象。假设暂时将此对象称为 ArrayAnnotation,我们需要一个满足 mypypytype 的解决方案,以解决以下情况

@jit
def f(x: ArrayAnnotation) -> ArrayAnnotation:
  assert isinstance(x, core.Tracer)
  return x

这可以通过多种方法实现,例如

  • 使用类型联合:ArrayAnnotation = Union[Array, Tracer]

  • 创建一个接口文件,声明 TracerArray 应该被视为 ArrayAnnotation 的子类。

  • 重构 ArrayTracer,使 ArrayAnnotation 成为两者的真正基类。

运行时实例检查#

我们还必须提供一个可以用于鸭子类型运行时 isinstance 检查的对象。假设暂时将此对象称为 ArrayInstance,我们需要一个通过以下运行时检查的解决方案

def f(x):
  return isinstance(x, ArrayInstance)
x = jnp.array([1, 2, 3])
assert f(x)       # x will be an array
assert jit(f)(x)  # x will be a tracer

同样,可以使用几种机制来实现这一点

  • 覆盖 type(ArrayInstance).__instancecheck__ 以对 ArrayTracer 对象返回 True;这就是 jnp.ndarray 的当前实现方式 (源代码)。

  • ArrayInstance 定义为抽象基类,并将其动态注册到 ArrayTracer

  • 重构 ArrayTracer,使 ArrayInstance 成为 ArrayTracer 的真正基类

我们需要做出一个决定,即 ArrayAnnotationArrayInstance 应该是同一个对象还是不同的对象。这里有一些先例;例如,在核心 Python 语言规范中,typing.Dicttyping.List 为了标注目的而存在,而内置的 dictlist 用于实例检查。但是,DictList 在较新的 Python 版本中已被弃用,转而使用 dictlist 来进行标注和实例检查。

效仿 NumPy#

在 NumPy 的情况下,np.typing.NDArray 用于类型标注,而 np.ndarray 用于实例检查(以及数组类型标识)。鉴于此,遵循 NumPy 的先例并实现以下内容可能是合理的

  • jax.Array 是设备上数组的实际类型。

  • jax.typing.NDArray 是用于鸭子类型数组标注的对象。

  • jax.numpy.ndarray 是用于鸭子类型数组实例检查的对象。

这对于 NumPy 高级用户来说可能感觉很自然,但是这种三方分离可能会造成混淆:对于实例检查和标注,使用哪一个并不立即明了。

统一实例检查和标注#

另一种方法是通过上面提到的覆盖机制统一类型检查和标注。

选项 1:部分统一#

部分统一可能看起来像这样

  • jax.Array 是设备上数组的实际类型。

  • jax.typing.Array 是用于鸭子类型数组标注的对象(通过 .pyi 接口在 ArrayTracer 上)。

  • jax.typing.Array 也是用于鸭子类型实例检查的对象(通过其元类中的 __isinstance__ 覆盖)。

在这种方法中,jax.numpy.ndarray 将成为 jax.typing.Array 的一个简单别名,以实现向后兼容性。

选项 2:通过覆盖实现完全统一#

或者,我们可以选择通过覆盖实现完全统一

  • jax.Array 是设备上数组的实际类型。

  • jax.Array 也是用于鸭子类型数组标注的对象(通过 .pyi 接口在 Tracer 上)。

  • jax.Array 也是用于鸭子类型实例检查的对象(通过其元类中的 __isinstance__ 覆盖)。

这里,jax.numpy.ndarray 将成为 jax.Array 的一个简单别名,以实现向后兼容性。

选项 3:通过类层次结构实现完全统一#

最后,我们可以选择通过重构类层次结构并将鸭子类型替换为 OOP 对象层次结构来实现完全统一。

  • jax.Array 是设备上数组的实际类型。

  • jax.Array 也是用于数组类型注释的对象,通过确保 Tracer 继承自 jax.Array

  • jax.Array 也是通过相同机制用于实例检查的对象。

这里 jnp.ndarray 可以是 jax.Array 的别名。这种最终方法在某种意义上是最纯粹的,但从 OOP 设计的角度来看,它有点强迫性 (Tracer Array?)。

选项 4:通过类层次结构实现部分统一#

我们可以通过让 Tracer 和设备上数组的类继承自一个公共基类来使类层次结构更加合理。例如

  • jax.ArrayTracer 以及设备上数组实际类型的基类,该类型可能是 jax._src.ArrayImpl 或类似类型。

  • jax.Array 是用于数组类型注释的对象。

  • jax.Array 也是用于实例检查的对象。

这里 jnp.ndarray 将是 Array 的别名。从 OOP 的角度来看,这可能更加纯粹,但与选项 2 和 3 相比,它放弃了 type(x) is jax.Array 将评估为 True 的概念。

评估#

考虑每种潜在方法的整体优势和劣势。

  • 从用户的角度来看,统一的方法(选项 2 和 3)可以说是最好的,因为它们消除了记住哪些对象用于实例检查或注释的认知负担:jax.Array 是你需要知道的全部。

  • 但是,选项 2 和 3 都引入了某些奇怪或令人困惑的行为。选项 2 依赖于可能令人困惑的实例检查覆盖,这些覆盖 在 pybind11 中定义的类中不受良好支持。选项 3 要求 Tracer 是数组的子类。这打破了继承模型,因为它将要求 Tracer 对象携带 Array 对象的所有负担(数据缓冲区、分片、设备等)。

  • 选项 4 在 OOP 意义上更加纯粹,并且避免了对典型实例检查或类型注释行为的任何覆盖。权衡是设备上数组的实际类型变成了某些独立的东西(这里为 jax._src.ArrayImpl)。但绝大多数用户永远不需要直接接触这个私有实现。

这里有不同的权衡,但在讨论后,我们最终选择了选项 4 作为我们的前进方向。

实施计划#

为了推进类型注释,我们将执行以下操作

  1. 迭代此 JEP 文档,直到开发人员和利益相关者都认可。

  2. 创建一个私有 jax._src.typing(目前不提供任何公共 API)并在其中放置上面提到的第一层简单类型。

    • 目前将 Array = Any 作为别名,因为这需要更仔细的思考。

    • ArrayLike:作为正常 jax.numpy 函数的输入有效的类型联合。

    • DType / DTypeLike(注意:numpy 使用驼峰式大小写 DType;为了便于使用,我们应该遵循此约定)。

    • Shape / NamedShape / ShapeLike

    这些的开始在 #12300 中完成。

  3. 开始构建一个遵循上一节中选项 4 的 jax.Array 基类。最初,这将在 Python 中定义,并使用当前在 jnp.ndarray 实现中找到的动态注册机制来确保 isinstance 检查的正确行为。每个跟踪器和类似数组类的 pyi 覆盖将确保类型注释的正确行为。然后,jnp.ndarray 可以被设置为 jax.Array 的别名。

  4. 作为测试,使用这些新的类型定义根据上述指南全面注释 jax.lax 中的函数。

  5. 继续以一个模块为单位添加其他注释,重点关注公共 API 函数。

  6. 同时,开始在 pybind11 中重新实现一个 jax.Array 基类,以便 ArrayImplTracer 可以从它继承。使用 pyi 定义来确保静态类型检查器识别类的适当属性。

  7. 一旦 jax.Arrayjax._src.ArrayImpl 完全落地,删除这些临时的 Python 实现。

  8. 完成所有操作后,创建一个公共 jax.typing 模块,使上述类型可供用户使用,以及使用 JAX 的代码的注释最佳实践文档。

我们将在 #12049 中跟踪这项工作,从该工作中获得此 JEP 的编号。