JAX 的类型注解路线图#

  • 作者:jakevdp

  • 日期:2022 年 8 月

背景#

Python 3.0 引入了可选的函数注解(PEP 3107),后来在 Python 3.5 版本发布前后被编纂用于静态类型检查(PEP 484)。在某种程度上,类型注解和静态类型检查已成为许多 Python 开发工作流程不可或缺的一部分,为此,我们在 JAX API 的许多地方添加了注解。目前 JAX 中的类型注解状态有些零散,并且由于更基本的设计问题,添加更多注解的工作受到了阻碍。本文档试图总结这些问题,并为 JAX 中类型注解的目标和非目标生成路线图。

为什么我们需要这样的路线图?更好/更全面的类型注解是用户(包括内部和外部用户)经常提出的要求。此外,我们经常收到外部用户的拉取请求(例如,PR #9917, PR #10322)寻求改进 JAX 的类型注解:对于审查代码的 JAX 团队成员来说,并不总是清楚这些贡献是否有益,尤其是在它们引入复杂的协议来解决 JAX 对 Python 使用进行全面注解时固有的挑战时。本文档详细介绍了 JAX 在包中进行类型注解的目标和建议。

为什么要进行类型注解?#

一个 Python 项目可能希望注解其代码库的原因有很多;我们在本文档中将它们总结为级别 1、级别 2 和级别 3。

级别 1:作为文档的注解#

最初在 PEP 3107 中引入时,类型注解的部分动机是能够将它们用作函数参数类型和返回类型的简洁的内联文档。JAX 长期以来一直以这种方式使用注解;一个例子是创建别名为 Any 的类型名称的常见模式。一个例子可以在 lax/slicing.py 中找到 [源代码]

Array = Any
Shape = core.Shape

def slice(operand: Array, start_indices: Sequence[int],
          limit_indices: Sequence[int],
          strides: Optional[Sequence[int]] = None) -> Array:
  ...

为了静态类型检查的目的,这种使用 Array = Any 进行数组类型注解的方式对参数值没有任何约束(Any 等同于根本没有注解),但它确实可以作为开发人员有用的代码内文档的一种形式。

为了生成文档,别名的名称会丢失(HTML 文档中,jax.lax.slice 的操作数报告为类型 Any),因此文档的好处不会超出源代码范围(尽管我们可以启用一些 sphinx-autodoc 选项来改进这一点:请参阅 autodoc_type_aliases)。

这种级别的类型注解的好处是,用 Any 注解一个值永远不会出错,因此它将以文档的形式为开发人员和用户提供具体的好处,而不会增加满足任何特定静态类型检查器更严格需求的复杂性。

级别 2:用于智能自动完成的注解#

许多现代 IDE 利用类型注解作为 智能代码完成系统的输入。这方面的一个例子是 VSCode 的 Pylance 扩展,它使用 Microsoft 的 pyright 静态类型检查器作为 VSCode 的 IntelliSense 完成的信息来源。

这种类型检查的使用需要比上面使用的简单别名更进一步;例如,知道 slice 函数返回一个名为 ArrayAny 别名,并不会为代码完成引擎添加任何有用的信息。但是,如果我们使用 DeviceArray 返回类型注解该函数,则自动完成功能将知道如何填充结果的命名空间,从而能够在开发过程中建议更相关的自动完成。

JAX 已经开始在一些地方添加此级别的类型注解;一个例子是 jax.random 包中的 jnp.ndarray 返回类型 [源代码]

def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray:
  ...

在这种情况下,jnp.ndarray 是一个抽象基类,它预先声明了 JAX 数组的属性和方法(请参阅源代码),因此 VSCode 中的 Pylance 可以从此函数的结果提供完整的自动完成集合。这是一个显示结果的屏幕截图

VSCode Intellisense Screenshot

自动完成字段中列出了抽象 ndarray 类声明的所有方法和属性。我们将在下面进一步讨论为什么有必要创建这个抽象类,而不是直接使用 DeviceArray 进行注解。

级别 3:用于静态类型检查的注解#

如今,当人们考虑 Python 代码中类型注解的目的时,静态类型检查通常是首先想到的。虽然 Python 不会对类型进行任何运行时检查,但存在一些成熟的静态类型检查工具,可以将其作为 CI 测试套件的一部分进行检查。对 JAX 最重要的工具如下

  • python/mypy 或多或少是开放 Python 世界中的标准。JAX 目前在 Github Actions CI 检查中对源文件的子集运行 mypy。

  • google/pytype 是 Google 的静态类型检查器,Google 内依赖于 JAX 的项目经常使用它。

  • microsoft/pyright 作为 VSCode 中用于前面提到的 Pylance 完成的静态类型检查器,非常重要。

完整的静态类型检查是所有类型注解应用中最严格的,因为它会在您的类型注解不完全正确时显示错误。一方面,这很好,因为您的静态类型分析可能会捕获错误的类型注解(例如,DeviceArray 方法在 jnp.ndarray 抽象类中缺失的情况)。

另一方面,在经常依赖鸭子类型而不是严格的类型安全 API 的包中,这种严格性会使类型检查过程非常脆弱。您目前会在 JAX 代码库中数百个地方发现诸如 #type: ignore(用于 mypy)或 #pytype: disable(用于 pytype)之类的代码注释。这些通常表示出现了类型问题的情况;它们可能是 JAX 类型注解中的不准确之处,也可能是静态类型检查器正确跟踪代码中控制流的能力不准确之处。有时,它们是由于 pytype 或 mypy 行为中真实而微妙的错误造成的。在极少数情况下,它们可能是由于 JAX 使用 Python 模式而造成的,这些模式很难甚至不可能用 Python 的静态类型注解语法来表达。

JAX 的类型注解挑战#

JAX 当前具有混合了不同样式且针对上述所有三个级别的类型注解的类型注解。部分原因是 JAX 的源代码对 Python 的类型注解系统提出了一些独特的挑战。我们将在这里概述它们。

挑战 1:pytype、mypy 和开发人员摩擦#

JAX 当前面临的一个挑战是,包开发必须满足两个不同的静态类型检查系统的约束,即 pytype(由内部 CI 和 Google 内部项目使用)和 mypy(由外部 CI 和外部依赖项使用)。尽管两个类型检查器的行为有广泛的重叠,但正如 JAX 代码库中大量的 #type: ignore#pytype: disable 语句所证明的那样,每个类型检查器都有其独特的角落案例。

这会在开发中产生摩擦:内部贡献者可能会迭代直到测试通过,结果发现他们的 pytype 批准的代码在导出时违反了 mypy。对于外部贡献者,情况通常相反:最近的一个例子是 #9596,它在未能通过 Google 内部 pytype 检查后不得不回滚。每次我们将类型注解从级别 1(到处都是 Any)移动到级别 2 或 3(更严格的注解)时,都会为这种令人沮丧的开发者体验引入更多可能性。

挑战 2:数组鸭子类型#

为 JAX 代码添加注解的一个特殊挑战在于其大量使用鸭子类型。一个被标记为 Array 的函数的输入通常可以是多种不同的类型:一个 JAX DeviceArray、一个 NumPy np.ndarray、一个 NumPy 标量、一个 Python 标量、一个 Python 序列、一个带有 __array__ 属性的对象、一个带有 __jax_array__ 属性的对象,或者任何类型的 jax.Tracer。因此,像 def func(x: DeviceArray) 这样的简单注解是不够的,并且会导致许多有效用例的误报。这意味着 JAX 函数的类型注解不会简短或简单,我们必须有效地开发一套类似于 numpy.typing中的 JAX 特定类型扩展。

挑战 3:转换和装饰器#

JAX 的 Python API 严重依赖函数转换(jit()vmap()grad() 等),这种类型的 API 为静态类型分析带来了特殊的挑战。装饰器的灵活注解一直是 mypy 包中一个长期存在的问题,直到最近才通过引入 ParamSpec 解决,如 PEP 612 中所述,并在 Python 3.10 中添加。由于 JAX 遵循 NEP 29,它不能在 2024 年中期之后的某个时间之前依赖 Python 3.10 的功能。同时,Protocols 可以作为部分解决方案(JAX 在 #9950 中为 jit 和其他方法添加了此功能),并且可以通过 typing_extensions 包使用 ParamSpec(一个原型在 #9999 中),但这目前揭示了 mypy 中的基本错误(参见 python/mypy#12593)。总而言之:目前尚不清楚是否可以在当前 Python 类型注解工具的约束范围内适当地注解 JAX 函数转换的 API。

挑战 4:数组注解缺乏粒度#

这里的另一个挑战是 Python 中所有面向数组的 API 普遍存在的,并且多年来一直是 JAX 讨论的一部分(参见 #943)。类型注解与对象的 Python 类或类型有关,而在基于数组的语言中,类的属性通常更重要。在 NumPy、JAX 和类似包的情况下,我们通常希望注解特定的数组形状和数据类型。

例如,jnp.linspace 函数的参数必须是标量值,但在 JAX 中,标量由零维数组表示。因此,为了使注解不引发误报,我们必须允许这些参数是任意数组。另一个例子是 jax.random.choice 的第二个参数,当 shape=() 时,它必须具有 dtype=int。Python 计划通过可变类型泛型启用具有此粒度级别的类型注解(参见 PEP 646,计划用于 Python 3.11),但与 ParamSpec 一样,对该功能的支持需要一段时间才能稳定下来。

目前有一些第三方项目可能会有所帮助,特别是 google/jaxtyping,但这使用了非标准注解,可能不适合注解核心 JAX 库本身。总的来说,数组类型粒度挑战不如其他挑战那么严重,因为主要的影响是类数组注解将不如它们本应的那样具体。

挑战 5:从 NumPy 继承的不精确 API#

JAX 面向用户的 API 的很大一部分是从 jax.numpy 子模块中的 NumPy 继承的。NumPy 的 API 是在静态类型检查成为 Python 语言的一部分之前开发的,并且遵循 Python 的历史建议,使用 鸭子类型/ EAFP 编码风格,其中不鼓励在运行时进行严格的类型检查。作为此示例的具体例子,请考虑 numpy.tile() 函数,其定义如下

def tile(A, reps):
  try:
    tup = tuple(reps)
  except TypeError:
    tup = (reps,)
  d = len(tup)
  ...

这里的意图reps 应包含一个 int 或一个 int 值的序列,但是实现允许 tup 是任何可迭代对象。在向这种鸭子类型的代码添加注解时,我们可以采取以下两种路线之一

  1. 我们可以选择注解函数 API 的意图,这里可能类似于 reps: Union[int, Sequence[int]]

  2. 相反,我们可以选择注解函数的实现,这里可能类似于 reps: Union[ConvertibleToInt, Iterable[ConvertibleToInt]],其中 ConvertibleToInt 是一个特殊的协议,涵盖了我们的函数将输入转换为整数的确切机制(即通过 __int__、通过 __index__、通过 __array__ 等)。还要注意,从严格意义上讲,这里的 Iterable 并不足够,因为 Python 中有些对象具有鸭子类型,可以作为可迭代对象,但不能满足针对 Iterable 的静态类型检查(即通过 __getitem__ 而不是 __iter__ 可迭代的对象)。

注解意图的 #1 的优点在于,这些注解在传达 API 契约方面对用户更有用;而对于开发人员来说,灵活性为必要时的重构留有余地。缺点(特别是对于像 JAX 这样逐步类型的 API)是,很可能存在运行正确的用户代码,但会被类型检查器标记为不正确。对现有的鸭子类型 API 进行逐步类型化意味着当前的注解隐式为 Any,因此将其更改为更严格的类型可能会向用户显示为重大更改。

从广义上讲,注解意图更好地服务于 1 级类型检查,而注解实现更好地服务于 3 级类型检查,而 2 级则更像是混合体(当涉及到 IDE 中的注解时,意图和实现都很重要)。

JAX 类型注解路线图#

考虑到这个框架(1/2/3 级)和 JAX 特有的挑战,我们可以开始制定在整个 JAX 项目中实现一致类型注解的路线图。

指导原则#

对于 JAX 类型注解,我们将遵循以下原则

类型注解的目的#

我们希望尽可能支持完整的1、2 和 3 级类型注解。特别是,这意味着我们应该对公共 API 函数的输入和输出都进行限制性类型注解。

为意图注解#

JAX 类型注解通常应指示 API 的意图,而不是实现,以便注解有助于传达 API 的契约。这意味着有时运行时有效的输入可能不会被静态类型检查器识别为有效(一个例子可能是传递任意迭代器来代替注解为 Shape = Sequence[int] 的形状)。

输入应允许类型化#

JAX 函数和方法的输入应尽可能允许地类型化:例如,虽然形状通常是元组,但接受形状的函数应接受任意序列。类似地,接受 dtype 的函数不需要 np.dtype 类的实例,而是任何可转换为 dtype 的对象。这可能包括字符串、内置标量类型或标量对象构造函数,例如 np.float64jnp.float64。为了使整个软件包尽可能统一,我们将添加一个 jax.typing 模块,其中包含常见的类型规范,从广泛的类别开始,例如

  • ArrayLike 将是可以隐式转换为数组的任何内容的联合:例如,jax 数组、numpy 数组、JAX 跟踪器以及 python 或 numpy 标量

  • DTypeLike 将是可以隐式转换为 dtype 的任何内容的联合:例如,numpy dtypes、numpy dtype 对象、jax dtype 对象、字符串和内置类型。

  • ShapeLike 将是可以转换为形状的任何内容的联合:例如,整数或类整数对象的序列。

  • 等等。

请注意,这些通常比 numpy.typing 中使用的等效协议更简单。例如,在 DTypeLike 的情况下,JAX 不支持结构化 dtypes,因此 JAX 可以使用更简单的实现。类似地,在 ArrayLike 中,JAX 通常不支持使用列表或元组输入来代替数组,因此类型定义将比 NumPy 模拟更简单。

输出应严格类型化#

相反,函数和方法的输出应该尽可能严格地进行类型标注:例如,对于返回数组的 JAX 函数,输出应该使用类似于 jnp.ndarray 的方式进行标注,而不是使用 ArrayLike。返回 dtype 的函数应该始终标注为 np.dtype,而返回 shape 的函数应该始终标注为 Tuple[int] 或严格类型的 NamedShape 等效形式。为此,我们将在 jax.typing 中实现上述宽松类型的几种严格类型模拟,即

  • ArrayNDArray (见下文)用于类型标注目的,实际上等同于 Union[Tracer, jnp.ndarray],应该用于标注数组输出。

  • DTypenp.dtype 的别名,或许还能够表示 JAX 中使用的关键类型和其他泛化。

  • Shape 本质上是 Tuple[int, ...],或许还有一些额外的灵活性来考虑动态形状。

  • NamedShapeShape 的扩展,允许使用 JAX 内部使用的命名形状。

  • 等等。

我们还将探讨是否可以用将 ndarray 定义为 Array 或类似对象的别名来取代当前 jax.numpy.ndarray 的实现。

倾向于简单#

除了在 jax.typing 中收集的常用类型协议之外,我们应该倾向于简单。我们应该避免为传递给 API 函数的参数构建过于复杂的协议,而是使用简单的联合,例如 Union[simple_type, Any],如果 API 的完整类型规范无法简洁地指定。这是一种折衷方案,可以实现 1 级和 2 级注释的目标,同时为了避免不必要的复杂性而放弃 3 级注释。

避免不稳定的类型机制#

为了不增加不必要的开发摩擦(由于内部/外部 CI 差异),我们希望在我们使用的类型注释结构中保持保守:特别是,对于最近引入的机制,例如 ParamSpecPEP 612)和可变类型泛型(PEP 646),我们希望等到 mypy 和其他工具中的支持成熟稳定后再依赖它们。

这其中一个影响是,目前,当函数被 JAX 变换(如 jitvmapgrad 等)修饰时,JAX 实际上会剥离所有已修饰函数的注释。虽然这很不幸,但在撰写本文时,mypy 在 ParamSpec 提供的潜在解决方案方面存在一连串的不兼容性问题(参见 ParamSpec mypy bug tracker),因此我们认为它目前尚未准备好在 JAX 中完全采用。一旦对此类功能的支持稳定下来,我们将在未来重新审视这个问题。

同样,目前我们将避免添加 jaxtyping 项目提供的更复杂和细粒度的数组类型注释。这是一个我们可以在未来重新考虑的决定。

Array 类型设计注意事项#

如上所述,由于 JAX 广泛使用鸭子类型,即在 jax 变换中传递和返回 Tracer 对象代替实际数组,因此 JAX 中数组的类型标注提出了独特的挑战。这变得越来越令人困惑,因为用于类型标注的对象通常与用于运行时实例检查的对象重叠,并且可能对应或不对应于相关对象的实际类型层次结构。对于 JAX,我们需要为两种上下文提供鸭子类型的对象:静态类型标注运行时实例检查

以下讨论将假设 jax.Array 是设备上数组的运行时类型,这目前还不是这种情况,但一旦 #12016 中的工作完成,就会是这种情况。

静态类型标注#

我们需要提供一个可以用于鸭子类型类型标注的对象。假设我们暂时将这个对象称为 ArrayAnnotation,我们需要一个解决方案,该解决方案可以使 mypypytype 满足以下情况

@jit
def f(x: ArrayAnnotation) -> ArrayAnnotation:
  assert isinstance(x, core.Tracer)
  return x

这可以通过多种方法实现,例如

  • 使用类型联合:ArrayAnnotation = Union[Array, Tracer]

  • 创建一个接口文件,该文件声明 TracerArray 应被视为 ArrayAnnotation 的子类。

  • 重构 ArrayTracer,使 ArrayAnnotation 成为两者的真正基类。

运行时实例检查#

我们还必须提供一个可用于鸭子类型运行时 isinstance 检查的对象。假设我们暂时将这个对象称为 ArrayInstance,我们需要一个能够通过以下运行时检查的解决方案

def f(x):
  return isinstance(x, ArrayInstance)
x = jnp.array([1, 2, 3])
assert f(x)       # x will be an array
assert jit(f)(x)  # x will be a tracer

同样,有几种机制可以用于此

  • 重写 type(ArrayInstance).__instancecheck__ 以对 ArrayTracer 对象返回 True;这就是当前实现 jnp.ndarray 的方式(来源)。

  • ArrayInstance 定义为抽象基类,并动态将其注册到 ArrayTracer

  • 重构 ArrayTracer,使 ArrayInstance 成为 ArrayTracer 的真正基类

我们需要做出的一个决定是,ArrayAnnotationArrayInstance 应该是相同还是不同的对象。这里有一些先例;例如,在核心 Python 语言规范中,typing.Dicttyping.List 用于标注,而内置的 dictlist 则用于实例检查。然而,在较新的 Python 版本中,DictList弃用,取而代之的是使用 dictlist 来进行标注和实例检查。

遵循 NumPy 的引导#

在 NumPy 中,np.typing.NDArray 用于类型标注,而 np.ndarray 用于实例检查(以及数组类型标识)。鉴于此,符合 NumPy 的先例并实现以下内容可能是合理的

  • jax.Array 是设备上数组的实际类型。

  • jax.typing.NDArray 是用于鸭子类型数组标注的对象。

  • jax.numpy.ndarray 是用于鸭子类型数组实例检查的对象。

这对于 NumPy 的高级用户来说可能感觉很自然,但是这种三岔口可能会让人感到困惑:选择哪个用于实例检查和标注并不立即清楚。

统一实例检查和标注#

另一种方法是通过上述重写机制统一类型检查和标注。

选项 1:部分统一#

部分统一可能如下所示:

  • jax.Array 是设备上数组的实际类型。

  • jax.typing.Array 是用于鸭子类型数组注释的对象(通过 ArrayTracer 上的 .pyi 接口)。

  • jax.typing.Array 也是用于鸭子类型实例检查的对象(通过其元类中的 __isinstance__ 重写)。

在这种方法中,为了向后兼容,jax.numpy.ndarray 将成为 jax.typing.Array 的简单别名。

选项 2:通过重写实现完全统一#

或者,我们可以选择通过重写实现完全统一。

  • jax.Array 是设备上数组的实际类型。

  • jax.Array 也是用于鸭子类型数组注释的对象(通过 Tracer 上的 .pyi 接口)。

  • jax.Array 也是用于鸭子类型实例检查的对象(通过其元类中的 __isinstance__ 重写)。

在这里,为了向后兼容,jax.numpy.ndarray 将成为 jax.Array 的简单别名。

选项 3:通过类层次结构实现完全统一#

最后,我们可以选择通过重构类层次结构并将鸭子类型替换为 OOP 对象层次结构来实现完全统一。

  • jax.Array 是设备上数组的实际类型。

  • jax.Array 也是用于数组类型注释的对象,通过确保 Tracer 继承自 jax.Array

  • jax.Array 也是用于实例检查的对象,通过相同的机制。

在这里,jnp.ndarray 可以是 jax.Array 的别名。从某种意义上说,最后一种方法是最纯粹的,但从 OOP 设计的角度来看有点勉强(Tracer Array 吗?)。

选项 4:通过类层次结构实现部分统一#

我们可以通过使 Tracer 和设备上数组的类继承自一个公共基类来使类层次结构更合理。例如:

  • jax.ArrayTracer 以及设备上数组实际类型的基类,后者可能是 jax._src.ArrayImpl 或类似的东西。

  • jax.Array 是用于数组类型注释的对象。

  • jax.Array 也是用于实例检查的对象。

在这里,jnp.ndarray 将是 Array 的别名。从 OOP 的角度来看,这可能更纯粹,但与选项 2 和 3 相比,它放弃了 type(x) is jax.Array 将求值为 True 的概念。

评估#

考虑到每种潜在方法的总体优缺点:

  • 从用户的角度来看,统一的方法(选项 2 和 3)可以说是最好的,因为它们消除了记住用于实例检查或注释的对象的认知开销:您只需要知道 jax.Array

  • 但是,选项 2 和 3 都引入了一些奇怪和/或令人困惑的行为。选项 2 依赖于可能令人困惑的实例检查重写,对于 pybind11 中定义的类,这些重写 支持不佳。选项 3 要求 Tracer 成为数组的子类。这打破了继承模型,因为它要求 Tracer 对象携带 Array 对象的所有负担(数据缓冲区、分片、设备等)。

  • 选项 4 在 OOP 意义上更纯粹,并且避免了对典型实例检查或类型注释行为进行任何重写的需要。缺点是,设备上数组的实际类型变成了单独的东西(这里是 jax._src.ArrayImpl)。但是绝大多数用户永远不必直接接触此私有实现。

这里有不同的权衡,但在讨论之后,我们选择了选项 4 作为我们的前进方向。

实施计划#

为了推进类型注释,我们将执行以下操作:

  1. 迭代此 JEP 文档,直到开发人员和利益相关者都认可。

  2. 创建一个私有的 jax._src.typing(目前不提供任何公共 API),并在其中放入上述的第一级简单类型。

    • 暂时将 Array = Any 作为别名,因为这需要更多的思考。

    • ArrayLike:作为正常 jax.numpy 函数的输入的有效类型的联合。

    • DType / DTypeLike(注意:numpy 使用驼峰式 DType;为了易于使用,我们应该遵循此约定)。

    • Shape / NamedShape / ShapeLike

    这方面的工作已在 #12300 中开始。

  3. 开始开发遵循上一节中选项 4 的 jax.Array 基类。最初,这将用 Python 定义,并使用 jnp.ndarray 实现中当前存在的动态注册机制,以确保 isinstance 检查的正确行为。每个跟踪器和类似数组的类的 pyi 重写将确保类型注释的正确行为。jnp.ndarray 可以变成 jax.Array 的别名。

  4. 作为测试,使用这些新的类型定义,根据上述指南,全面注释 jax.lax 中的函数。

  5. 一次添加一个模块的其他注释,重点是公共 API 函数。

  6. 同时,开始在 pybind11 中重新实现 jax.Array 基类,以便 ArrayImplTracer 可以从中继承。使用 pyi 定义以确保静态类型检查器识别该类的适当属性。

  7. 一旦 jax.Arrayjax._src.ArrayImpl 完全落地,请删除这些临时的 Python 实现。

  8. 全部完成后,创建一个公共的 jax.typing 模块,该模块将使上述类型可供用户使用,并提供使用 JAX 的代码的注释最佳实践文档。

我们将在 #12049 中跟踪此工作,此 JEP 的编号由此而来。