JAX 的类型注解路线图#

  • 作者:jakevdp

  • 日期:2022 年 8 月

背景#

Python 3.0 引入了可选的函数注解(PEP 3107),这些注解后来在 Python 3.5 版本发布前后被正式用于静态类型检查(PEP 484)。在某种程度上,类型注解和静态类型检查已成为许多 Python 开发工作流程不可或缺的一部分,为此,我们在 JAX API 的许多地方添加了注解。JAX 中类型注解的现状有点零散,并且由于更根本的设计问题,添加更多注解的努力受到了阻碍。本文档尝试总结这些问题,并为 JAX 中类型注解的目标和非目标制定路线图。

为什么我们需要这样的路线图?更好/更全面的类型注解是用户(包括内部和外部用户)的频繁请求。此外,我们经常收到外部用户的拉取请求(例如,PR #9917PR #10322),旨在改进 JAX 的类型注解:审查代码的 JAX 团队成员并不总是清楚这些贡献是否有益,尤其是在他们引入复杂的协议来解决 JAX 对 Python 的使用进行全面注解时固有的挑战。本文档详细介绍了 JAX 在软件包中进行类型注解的目标和建议。

为什么要使用类型注解?#

Python 项目可能希望注解其代码库的原因有很多;我们将在本文档中将它们总结为级别 1、级别 2 和级别 3。

级别 1:将注解作为文档#

最初在 PEP 3107 中引入时,类型注解的部分动机是能够将它们用作函数参数类型和返回类型的简洁内联文档。JAX 长期以来一直以这种方式使用注解;一个例子是创建别名为 Any 的类型名称的常见模式。一个例子可以在 lax/slicing.py 中找到 [source]

Array = Any
Shape = core.Shape

def slice(operand: Array, start_indices: Sequence[int],
          limit_indices: Sequence[int],
          strides: Optional[Sequence[int]] = None) -> Array:
  ...

出于静态类型检查的目的,这种使用 Array = Any 进行数组类型注解的方式不对参数值施加任何约束(Any 等同于根本没有注解),但它确实为开发人员提供了一种有用的代码内文档形式。

为了生成文档,别名的名称会丢失(jax.lax.sliceHTML 文档 将操作数报告为类型 Any),因此文档的好处不会超出源代码(尽管我们可以启用一些 sphinx-autodoc 选项来改进这一点:请参阅 autodoc_type_aliases)。

此级别类型注解的一个好处是,使用 Any 注解一个值永远不会出错,因此它将以文档的形式为开发人员和用户提供具体的好处,而无需满足任何特定静态类型检查器的更严格需求的额外复杂性。

级别 2:用于智能自动完成的注解#

许多现代 IDE 都利用类型注解作为 智能代码完成系统的输入。这方面的一个例子是 VSCode 的 Pylance 扩展,它使用 Microsoft 的 pyright 静态类型检查器作为 VSCode 的 IntelliSense 完成的信息来源。

这种类型检查的使用需要比上面使用的简单别名更进一步;例如,知道 slice 函数返回一个名为 ArrayAny 别名不会向代码完成引擎添加任何有用的信息。但是,如果我们使用 DeviceArray 返回类型注解该函数,则自动完成功能将知道如何填充结果的命名空间,从而能够在开发过程中建议更相关的自动完成。

JAX 已开始在一些地方添加此级别的类型注解;一个例子是 jax.random 包中的 jnp.ndarray 返回类型 [source]

def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray:
  ...

在这种情况下,jnp.ndarray 是一个抽象基类,它向前声明 JAX 数组的属性和方法(请参阅源代码),因此 VSCode 中的 Pylance 可以提供此函数结果的所有自动完成功能。这是一个显示结果的屏幕截图

VSCode Intellisense Screenshot

自动完成字段中列出了抽象 ndarray 类声明的所有方法和属性。我们将在下面进一步讨论为什么有必要创建这个抽象类,而不是直接使用 DeviceArray 进行注解。

级别 3:用于静态类型检查的注解#

如今,当考虑 Python 代码中类型注解的目的时,人们首先想到的通常是静态类型检查。虽然 Python 不会进行任何类型的运行时检查,但存在几种成熟的静态类型检查工具,可以将其作为 CI 测试套件的一部分。对于 JAX 来说,最重要的工具如下

  • python/mypy 或多或少是开放 Python 世界中的标准。JAX 目前在 Github Actions CI 检查中对源文件的子集运行 mypy。

  • google/pytype 是 Google 的静态类型检查器,并且 Google 内部依赖于 JAX 的项目经常使用它。

  • microsoft/pyright 作为 VSCode 中用于先前提到的 Pylance 完成的静态类型检查器非常重要。

完整的静态类型检查是所有类型注解应用程序中最严格的,因为它会在您的类型注解不完全正确时显示错误。一方面,这很好,因为您的静态类型分析可能会捕获错误的类型注解(例如,jnp.ndarray 抽象类中缺少 DeviceArray 方法的情况)。

另一方面,这种严格性可能会使类型检查过程在通常依赖于鸭子类型而不是严格的类型安全 API 的软件包中变得非常脆弱。您目前会发现像 #type: ignore(对于 mypy)或 #pytype: disable(对于 pytype)这样的代码注释散落在 JAX 代码库中的数百个地方。这些通常表示出现了类型问题的情况;它们可能是 JAX 类型注解中的不准确之处,或者是静态类型检查器在正确跟踪代码中的控制流方面的能力不准确。有时,它们是由于 pytype 或 mypy 行为中的真实且细微的错误造成的。在极少数情况下,它们可能是由于 JAX 使用的 Python 模式难以甚至不可能用 Python 的静态类型注解语法来表达的事实造成的。

JAX 的类型注解挑战#

JAX 当前具有混合不同风格的类型注解,并且旨在用于上面讨论的所有三个级别的类型注解。部分原因是 JAX 的源代码对 Python 的类型注解系统提出了一些独特的挑战。我们将在此处概述它们。

挑战 1:pytype、mypy 和开发人员摩擦#

JAX 目前面临的一个挑战是,软件包开发必须满足两个不同的静态类型检查系统 pytype(由内部 CI 和内部 Google 项目使用)和 mypy(由外部 CI 和外部依赖项使用)的约束。尽管两个类型检查器在行为上具有广泛的重叠,但每个类型检查器都呈现出其独特的极端情况,JAX 代码库中大量的 #type: ignore#pytype: disable 语句证明了这一点。

这会在开发中造成摩擦:内部贡献者可能会迭代直到测试通过,但却发现在导出时,他们经过 pytype 批准的代码却不符合 mypy 的要求。对于外部贡献者,情况通常相反:最近的一个例子是 #9596,该问题在通过内部 Google pytype 检查失败后不得不回滚。每次我们将类型注释从 1 级(到处都是 Any)移动到 2 级或 3 级(更严格的注释)时,都会增加这种令人沮丧的开发者体验的可能性。

挑战 2:数组鸭子类型#

为 JAX 代码添加注释的一个特殊挑战是它大量使用了鸭子类型。通常,标记为 Array 的函数的输入可能是许多不同的类型之一:JAX 的 DeviceArray,NumPy 的 np.ndarray,NumPy 标量,Python 标量,Python 序列,具有 __array__ 属性的对象,具有 __jax_array__ 属性的对象,或任何形式的 jax.Tracer。因此,像 def func(x: DeviceArray) 这样的简单注释是不够的,并且会导致许多有效用例的误报。这意味着 JAX 函数的类型注释不会简短或琐碎,我们必须有效地开发一组类似于 numpy.typing中的 JAX 特定的类型扩展。

挑战 3:转换和装饰器#

JAX 的 Python API 严重依赖于函数转换(jit()vmap()grad() 等),这种类型的 API 对静态类型分析提出了特殊的挑战。装饰器的灵活注释一直是 mypy 包中的一个长期存在的问题,直到最近才通过引入 ParamSpec 解决,该问题在 PEP 612 中讨论并在 Python 3.10 中添加。由于 JAX 遵循 NEP 29,它在 2024 年中期之后的一段时间内都无法依赖 Python 3.10 的功能。在此期间,协议可以用作此问题的部分解决方案(JAX 在 #9950 中为 jit 和其他方法添加了此功能),并且可以通过 typing_extensions 包使用 ParamSpec(#9999 中有一个原型),但这目前揭示了 mypy 中的基本错误(请参阅 python/mypy#12593)。总而言之:尚不清楚 JAX 的函数转换的 API 是否可以在当前 Python 类型注释工具的约束范围内进行适当的注释。

挑战 4:数组注释缺乏粒度#

这里的另一个挑战在 Python 中所有面向数组的 API 中都很常见,并且已经成为 JAX 讨论的一部分好几年了(请参阅 #943)。类型注释与对象的 Python 类或类型有关,而在基于数组的语言中,类的属性通常更重要。在 NumPy、JAX 和类似包的情况下,我们通常希望注释特定的数组形状和数据类型。

例如,jnp.linspace 函数的参数必须是标量值,但在 JAX 中,标量由零维数组表示。因此,为了使注释不引发误报,我们必须允许这些参数为任意数组。另一个例子是 jax.random.choice 的第二个参数,当 shape=() 时,它的 dtype=int。Python 计划通过可变类型泛型启用具有这种粒度的类型注释(请参阅 PEP 646,计划用于 Python 3.11),但像 ParamSpec 一样,对该功能的支持将需要一段时间才能稳定下来。

有一些第三方项目可能在此期间有所帮助,特别是 google/jaxtyping,但这使用了非标准的注释,可能不适合注释 JAX 核心库本身。总而言之,数组类型粒度挑战不如其他挑战那么严重,因为主要的影响是类似数组的注释将不如它们可能的那样具体。

挑战 5:从 NumPy 继承的不精确的 API#

JAX 的大部分面向用户的 API 是从 jax.numpy 子模块中的 NumPy 继承的。NumPy 的 API 是在静态类型检查成为 Python 语言的一部分之前开发的,并遵循 Python 的历史建议,使用鸭子类型/EAFP 编码风格,不鼓励在运行时进行严格的类型检查。作为具体示例,请考虑 numpy.tile() 函数,该函数定义如下

def tile(A, reps):
  try:
    tup = tuple(reps)
  except TypeError:
    tup = (reps,)
  d = len(tup)
  ...

这里的意图reps 包含 intint 值的序列,但实现允许 tup 为任何可迭代对象。在向这种鸭子类型代码添加注释时,我们可以采用两种方式之一

  1. 我们可以选择注释函数 API 的意图,这里可能类似于 reps: Union[int, Sequence[int]]

  2. 相反,我们可以选择注释函数的实现,这里可能类似于 reps: Union[ConvertibleToInt, Iterable[ConvertibleToInt]],其中 ConvertibleToInt 是一个特殊的协议,涵盖了我们的函数将输入转换为整数的确切机制(即通过 __int____index____array__ 等)。另请注意,从严格意义上讲,Iterable 在这里是不够的,因为 Python 中存在一些对象,它们可以作为可迭代对象进行鸭子类型化,但不满足针对 Iterable 的静态类型检查(即通过 __getitem__ 而不是 __iter__ 可迭代的对象)。

注释意图(#1)的优点是,注释在传达 API 合约时对用户更有用;而对于开发人员来说,这种灵活性为必要时的重构留出了空间。缺点(特别是对于像 JAX 这样的渐进类型化的 API)是,很可能存在用户代码可以正确运行,但会被类型检查器标记为不正确的代码。对现有鸭子类型 API 进行渐进类型化意味着当前注释隐含地为 Any,因此将其更改为更严格的类型可能会对用户来说是重大更改。

总的来说,注释意图更好地服务于 1 级类型检查,而注释实现更好地服务于 3 级类型检查,而 2 级更像是混合包(在 IDE 中的注释方面,意图和实现都很重要)。

JAX 类型注释路线图#

考虑到这种框架(1/2/3 级)和 JAX 特有的挑战,我们可以开始制定在 JAX 项目中实现一致类型注释的路线图。

指导原则#

对于 JAX 类型注释,我们将遵循以下原则

类型注释的目的#

我们希望尽可能地支持完整的1、2 和 3 级类型注释。特别是,这意味着我们应该对公共 API 函数的输入和输出都具有限制性的类型注释。

注释意图#

JAX 类型注释通常应指示 API 的意图,而不是实现,以便注释有助于传达 API 的合同。这意味着有时在运行时有效的输入可能不会被静态类型检查器识别为有效(一个例子可能是传递任意迭代器代替注释为 Shape = Sequence[int] 的形状)。

输入应该是宽容类型的#

JAX 函数和方法的输入应尽可能宽松地进行类型标注:例如,虽然形状通常是元组,但接受形状的函数应接受任意序列。类似地,接受 dtype 的函数不需要 np.dtype 类的实例,而是任何可转换为 dtype 的对象。这可能包括字符串、内置标量类型或标量对象构造函数,如 np.float64jnp.float64。为了使整个包的类型标注尽可能统一,我们将添加一个 jax.typing 模块,其中包含常见的类型规范,从以下广泛类别开始:

  • ArrayLike 将是任何可以隐式转换为数组的对象的联合:例如,jax 数组、numpy 数组、JAX 追踪器以及 python 或 numpy 标量。

  • DTypeLike 将是任何可以隐式转换为 dtype 的对象的联合:例如,numpy dtypes、numpy dtype 对象、jax dtype 对象、字符串和内置类型。

  • ShapeLike 将是任何可以转换为形状的对象的联合:例如,整数或类整数对象的序列。

  • 等等。

请注意,这些通常比 numpy.typing 中使用的等效协议更简单。例如,对于 DTypeLike,JAX 不支持结构化的 dtypes,因此 JAX 可以使用更简单的实现。类似地,在 ArrayLike 中,JAX 通常不支持使用列表或元组代替数组作为输入,因此类型定义将比 NumPy 的类似定义更简单。

输出应严格类型标注#

相反,函数和方法的输出应尽可能严格地进行类型标注:例如,对于返回数组的 JAX 函数,输出应使用类似于 jnp.ndarray 而不是 ArrayLike 的类型进行标注。返回 dtype 的函数应始终标注为 np.dtype,返回形状的函数应始终标注为 Tuple[int] 或严格类型的 NamedShape 等效项。为此,我们将在 jax.typing 中实现上述宽松类型的几种严格类型标注的类似项,即:

  • 用于类型标注的 ArrayNDArray (见下文) 实际上等同于 Union[Tracer, jnp.ndarray],应用于标注数组输出。

  • DTypenp.dtype 的别名,可能还能够表示 JAX 中使用的关键类型和其他泛化。

  • Shape 本质上是 Tuple[int, ...],可能具有一些额外的灵活性,以适应动态形状。

  • NamedShapeShape 的扩展,允许使用 JAX 内部使用的命名形状。

  • 等等。

我们还将探索是否可以使用 Array 或类似的别名来替换当前 jax.numpy.ndarray 的实现,使其成为 ndarray 的别名。

倾向于简单#

除了在 jax.typing 中收集的通用类型协议之外,我们应该倾向于简单。我们应该避免为传递给 API 函数的参数构建过于复杂的协议,而是在无法简洁指定 API 的完整类型规范的情况下,使用简单的联合,例如 Union[simple_type, Any]。这是一个折衷方案,它实现了第 1 级和第 2 级标注的目标,同时为了避免不必要的复杂性而放弃了第 3 级。

避免不稳定的类型机制#

为了不增加不必要的开发摩擦(由于内部/外部 CI 的差异),我们希望在使用的类型标注构造方面保持保守:特别是,对于最近引入的机制,例如 ParamSpec (PEP 612) 和可变类型泛型 (PEP 646),我们希望等到 mypy 和其他工具中的支持成熟稳定后再依赖它们。

这样做的一个影响是,目前,当函数被 JAX 转换(例如 jitvmapgrad 等)修饰时,JAX 将有效地去除被修饰函数的所有标注。虽然这很不幸,但在撰写本文时,mypy 与 ParamSpec 提供的潜在解决方案存在一系列不兼容问题(请参阅 ParamSpec mypy 错误跟踪器),因此我们认为它目前不适合在 JAX 中完全采用。我们将在未来此类功能的支持稳定后重新考虑这个问题。

类似地,目前我们将避免添加 jaxtyping 项目提供的更复杂和细粒度的数组类型标注。我们可以在未来重新考虑这个决定。

Array 类型设计注意事项#

如上所述,由于 JAX 广泛使用鸭子类型(即在 jax 转换中传递和返回 Tracer 对象代替实际数组),因此 JAX 中数组的类型标注提出了独特的挑战。这变得越来越令人困惑,因为用于类型标注的对象通常与用于运行时实例检查的对象重叠,并且可能对应也可能不对应于所讨论对象的实际类型层次结构。对于 JAX,我们需要提供鸭子类型的对象,以用于两种情况:静态类型标注运行时实例检查

以下讨论将假设 jax.Array 是设备上数组的运行时类型,尽管目前还不是这种情况,但一旦 #12016 中的工作完成,就会变成这种情况。

静态类型标注#

我们需要提供一个可用于鸭子类型标注的对象。假设我们暂时将此对象称为 ArrayAnnotation,我们需要一个解决方案,该解决方案在如下情况下满足 mypypytype 的要求

@jit
def f(x: ArrayAnnotation) -> ArrayAnnotation:
  assert isinstance(x, core.Tracer)
  return x

这可以通过多种方法来实现,例如

  • 使用类型联合:ArrayAnnotation = Union[Array, Tracer]

  • 创建一个接口文件,声明 TracerArray 应被视为 ArrayAnnotation 的子类。

  • 重构 ArrayTracer,使 ArrayAnnotation 成为两者的真实基类。

运行时实例检查#

我们还必须提供一个可用于鸭子类型运行时 isinstance 检查的对象。假设我们暂时将此对象称为 ArrayInstance,我们需要一个解决方案来通过以下运行时检查

def f(x):
  return isinstance(x, ArrayInstance)
x = jnp.array([1, 2, 3])
assert f(x)       # x will be an array
assert jit(f)(x)  # x will be a tracer

同样,有几种机制可用于此

  • 覆盖 type(ArrayInstance).__instancecheck__ 以针对 ArrayTracer 对象都返回 True;这正是当前 jnp.ndarray 的实现方式 (source)。

  • ArrayInstance 定义为抽象基类,并将其动态注册到 ArrayTracer

  • 重构 ArrayTracer,使 ArrayInstance 成为 ArrayTracer 的真实基类。

我们需要做出的决定是 ArrayAnnotationArrayInstance 应该是相同的对象还是不同的对象。这里有一些先例;例如,在核心 Python 语言规范中,存在 typing.Dicttyping.List 用于标注,而内置的 dictlist 用于实例检查。但是,较新的 Python 版本中,DictList弃用,而改用 dictlist 进行标注和实例检查。

遵循 NumPy 的做法#

在 NumPy 中,np.typing.NDArray 用于类型标注,而 np.ndarray 用于实例检查(以及数组类型标识)。鉴于此,遵循 NumPy 的先例并实现以下内容可能是合理的

  • jax.Array 是设备上数组的实际类型。

  • jax.typing.NDArray 是用于鸭子类型数组标注的对象。

  • jax.numpy.ndarray 是用于鸭子类型数组实例检查的对象。

对于 NumPy 的高级用户来说,这可能感觉比较自然,但这种三叉分歧可能会让人困惑:选择哪个用于实例检查和注解并不显而易见。

统一实例检查和注解#

另一种方法是通过上述的重写机制来统一类型检查和注解。

方案 1:部分统一#

部分统一可能如下所示:

  • jax.Array 是设备上数组的实际类型。

  • jax.typing.Array 是用于鸭子类型数组注解的对象(通过 ArrayTracer 上的 .pyi 接口)。

  • jax.typing.Array 也是用于鸭子类型实例检查的对象(通过其元类中的 __isinstance__ 重写)。

在这种方法中,jax.numpy.ndarray 将成为 jax.typing.Array 的一个简单别名,以实现向后兼容。

方案 2:通过重写实现完全统一#

或者,我们可以选择通过重写来实现完全统一:

  • jax.Array 是设备上数组的实际类型。

  • jax.Array 也是用于鸭子类型数组注解的对象(通过 Tracer 上的 .pyi 接口)。

  • jax.Array 也是用于鸭子类型实例检查的对象(通过其元类中的 __isinstance__ 重写)。

在这里,jax.numpy.ndarray 将成为 jax.Array 的一个简单别名,以实现向后兼容。

方案 3:通过类层次结构实现完全统一#

最后,我们可以选择通过重组类层次结构并将鸭子类型替换为 OOP 对象层次结构来实现完全统一:

  • jax.Array 是设备上数组的实际类型。

  • jax.Array 也是用于数组类型注解的对象,通过确保 Tracer 继承自 jax.Array

  • jax.Array 也是用于实例检查的对象,通过相同的机制。

在这里,jnp.ndarray 可以是 jax.Array 的一个别名。从某种意义上说,最终这种方法是最纯粹的,但从 OOP 设计的角度来看,它有些牵强(Tracer Array 吗?)。

方案 4:通过类层次结构实现部分统一#

我们可以通过使 Tracer 和设备上数组的类继承自一个共同的基类来使类层次结构更加合理。例如:

  • jax.ArrayTracer 以及设备上数组实际类型的基类,后者可能是 jax._src.ArrayImpl 或类似的东西。

  • jax.Array 是用于数组类型注解的对象。

  • jax.Array 也是用于实例检查的对象。

在这里,jnp.ndarray 将是 Array 的一个别名。从 OOP 的角度来看,这可能更纯粹,但与方案 2 和 3 相比,它放弃了 type(x) is jax.Array 将计算为 True 的概念。

评估#

考虑到每种潜在方法的整体优缺点:

  • 从用户的角度来看,统一的方法(方案 2 和 3)可以说是最好的,因为它们消除了记住哪个对象用于实例检查或注解的认知负担:您只需要知道 jax.Array

  • 但是,方案 2 和 3 都引入了一些奇怪和/或令人困惑的行为。方案 2 依赖于实例检查的潜在混淆的重写,这对于在 pybind11 中定义的类支持不佳。方案 3 要求 Tracer 是数组的子类。这打破了继承模型,因为它要求 Tracer 对象携带 Array 对象的所有负担(数据缓冲区、分片、设备等)。

  • 方案 4 在 OOP 意义上更纯粹,并避免了对典型实例检查或类型注解行为进行任何重写的需要。其缺点是设备上数组的实际类型变成了单独的东西(这里是 jax._src.ArrayImpl)。但是绝大多数用户永远不必直接接触这个私有实现。

这里有不同的权衡,但在讨论之后,我们最终选择了方案 4 作为我们前进的方向。

实施计划#

为了推进类型注解,我们将执行以下操作:

  1. 迭代此 JEP 文档,直到开发人员和利益相关者都接受。

  2. 创建一个私有的 jax._src.typing(暂时不提供任何公共 API),并将上面提到的第一级简单类型放入其中。

    • 暂时将 Array = Any 设置为别名,因为这需要更多思考。

    • ArrayLike:作为普通 jax.numpy 函数输入的有效类型联合。

    • DType / DTypeLike(注意:numpy 使用驼峰式 DType;我们应该遵循此约定以方便使用)。

    • Shape / NamedShape / ShapeLike

    这项工作的开端已在 #12300 中完成。

  3. 开始开发遵循上一节方案 4 的 jax.Array 基类。最初,这将以 Python 定义,并使用当前在 jnp.ndarray 实现中找到的动态注册机制,以确保 isinstance 检查的正确行为。每个追踪器和类似数组的类的 pyi 重写将确保类型注解的正确行为。jnp.ndarray 然后可以成为 jax.Array 的别名。

  4. 作为测试,根据上述指南,使用这些新的类型定义来全面注解 jax.lax 中的函数。

  5. 一次添加一个模块的额外注解,重点是公共 API 函数。

  6. 与此同时,开始在 pybind11 中重新实现 jax.Array 基类,以便 ArrayImplTracer 可以继承它。使用 pyi 定义来确保静态类型检查器识别类的适当属性。

  7. 一旦 jax.Arrayjax._src.ArrayImpl 完全落地,请删除这些临时的 Python 实现。

  8. 全部完成后,创建一个公共 jax.typing 模块,使上述类型可供用户使用,并提供使用 JAX 的代码的注解最佳实践文档。

我们将在此 #12049 中跟踪此工作,此 JEP 从中获取其编号。