形状多态性#

当 JAX 在 JIT 模式下使用时,函数会针对每种输入类型和形状的组合进行追踪、降级为 StableHLO 并编译。在导出函数并在另一个系统上反序列化后,我们不再有可用的 Python 源代码,因此我们无法重新追踪和重新降级它。形状多态是 JAX 导出的一个特性,允许一些导出的函数用于整个输入形状族。这些函数在导出期间被追踪和降级一次,并且 Exported 对象包含能够在许多具体的输入形状上编译和执行函数所需的信息。我们通过在导出时指定包含维度变量(符号形状)的形状来实现这一点,如下例所示

>>> import jax
>>> from jax import export
>>> from jax import numpy as jnp
>>> def f(x):  # f: f32[a, b]
...   return jnp.concatenate([x, x], axis=1)

>>> # We construct symbolic dimension variables.
>>> a, b = export.symbolic_shape("a, b")

>>> # We can use the symbolic dimensions to construct shapes.
>>> x_shape = (a, b)
>>> x_shape
(a, b)

>>> # Then we export with symbolic shapes:
>>> exp: export.Exported = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct(x_shape, jnp.int32))
>>> exp.in_avals
(ShapedArray(int32[a,b]),)
>>> exp.out_avals
(ShapedArray(int32[a,2*b]),)

>>> # We can later call with concrete shapes (with a=3 and b=4), without re-tracing `f`.
>>> res = exp.call(np.ones((3, 4), dtype=np.int32))
>>> res.shape
(3, 8)

请注意,此类函数仍然会根据需要针对它们被调用的每个具体输入形状重新编译。仅保存追踪和降级过程。

在上面的示例中,jax.export.symbolic_shape() 用于将符号形状的字符串表示形式解析为维度表达式对象(类型为 _DimExpr),这些对象可以代替整数常量来构造形状。维度表达式对象重载了大多数整数运算符,因此在大多数情况下您可以像使用整数常量一样使用它们。有关详细信息,请参阅 使用维度变量进行计算

此外,我们还提供了 jax.export.symbolic_args_specs(),该函数可用于基于多态形状规范构造 jax.ShapeDtypeStruct 对象的 PyTree

>>> def f1(x, y): # x: f32[a, 1], y : f32[a, 4]
...  return x + y

>>> # Assuming you have some actual args with concrete shapes
>>> x = np.ones((3, 1), dtype=np.int32)
>>> y = np.ones((3, 4), dtype=np.int32)
>>> args_specs = export.symbolic_args_specs((x, y), "a, ...")
>>> exp = export.export(jax.jit(f1))(* args_specs)
>>> exp.in_avals
(ShapedArray(int32[a,1]), ShapedArray(int32[a,4]))

请注意,多态形状规范 "a, ..." 包含占位符 ...,该占位符将从参数 (x, y) 的具体形状中填充。占位符 ... 表示 0 个或多个维度,而占位符 _ 表示一个维度。jax.export.symbolic_args_specs() 支持参数的 PyTree,这些参数用于填充 dtypes 和任何占位符。该函数将构造一个参数规范的 PyTree(jax.ShapeDtypeStruct),该 PyTree 与传递给它的参数的结构匹配。多态形状规范可以是 PyTree 前缀,在这种情况下,一个规范应适用于多个参数,如上面的示例所示。请参阅 如何将可选参数与参数匹配

一些形状规范的示例

  • ("(b, _, _)", None) 可用于具有两个参数的函数,第一个参数是 3D 数组,具有一个应为符号的批次前导维度。第一个参数的其他维度和第二个参数的形状根据实际参数进行专门化。请注意,如果第一个参数是 3D 数组的 PyTree,所有 3D 数组都具有相同的前导维度但可能具有不同的尾随维度,则相同的规范也可以工作。第二个参数的值 None 表示该参数不是符号的。等效地,可以使用 ...

  • ("(batch, ...)", "(batch,)") 指定两个参数具有匹配的前导维度,第一个参数的秩至少为 1,第二个参数的秩为 1。

形状多态性的正确性#

我们希望相信,当为任何适用的具体形状编译和执行时,导出的程序会产生与原始 JAX 程序相同的结果。更确切地说

对于任何 JAX 函数 f 和任何包含符号形状的参数规范 arg_spec,以及任何形状与 arg_spec 匹配的具体参数 arg

  • 如果 JAX 本机执行在具体参数上成功:res = f(arg)

  • 并且如果使用符号形状导出成功:exp = export.export(f)(arg_spec)

  • 则编译并运行导出的内容将以相同的结果成功:res == exp.call(arg)

至关重要的是要理解 f(arg) 可以自由地重新调用 JAX 追踪机制,并且实际上它对每个不同的具体 arg 形状都这样做,而执行 exp.call(arg) 不能再使用 JAX 追踪(此执行可能发生在 f 的源代码不可用的环境中)。

确保这种形式的正确性很困难,并且在最困难的情况下导出失败。本章的其余部分介绍了如何处理这些失败。

使用维度变量进行计算#

JAX 会跟踪所有中间结果的形状。当这些形状依赖于维度变量时,JAX 会将它们计算为包含维度变量的符号维度表达式。维度变量表示大于或等于 1 的整数值。符号表达式可以表示对维度表达式和整数intnp.int 或任何可由 operator.index 转换的内容)应用算术运算符(加、减、乘、整除、模,包括 NumPy 变体 np.sumnp.prod 等)的结果。然后,这些符号维度可以在 JAX 原语和 API 的形状参数中使用,例如在 jnp.reshapejnp.arange、切片索引等中使用。

例如,在以下展平 2D 数组的代码中,计算 x.shape[0] * x.shape[1] 将符号维度 4 * b 计算为新形状

>>> f = lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],))
>>> arg_spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), jnp.int32)
>>> exp = export.export(jax.jit(f))(arg_spec)
>>> exp.out_avals
(ShapedArray(int32[4*b]),)

可以使用 jnp.array(x.shape[0]) 甚至 jnp.array(x.shape) 将维度表达式显式转换为 JAX 数组。这些操作的结果可以用作常规 JAX 数组,但不能再用作形状中的维度,例如,在 reshape

>>> exp = export.export(jax.jit(lambda x: jnp.array(x.shape[0]) + x))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32))
>>> exp.call(jnp.arange(3, dtype=np.int32))
Array([3, 4, 5], dtype=int32)

>>> exp = export.export(jax.jit(lambda x: x.reshape(jnp.array(x.shape[0]) + 2)))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32))  
Traceback (most recent call last):
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>].

当符号维度用于与非整数(例如,floatnp.floatnp.ndarray 或 JAX 数组)的算术运算中时,它会自动使用 jnp.array 转换为 JAX 数组。例如,在下面的函数中,由于 x.shape[0] 的所有出现都涉及与非整数标量或 JAX 数组的操作,因此它们会隐式转换为 jnp.array(x.shape[0])

>>> exp = export.export(jax.jit(
...     lambda x: (5. + x.shape[0],
...                x.shape[0] - np.arange(5, dtype=jnp.int32),
...                x + x.shape[0] + jnp.sin(x.shape[0]))))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b"), jnp.int32))
>>> exp.out_avals
(ShapedArray(float32[], weak_type=True),
 ShapedArray(int32[5]),
 ShapedArray(float32[b], weak_type=True))

>>> exp.call(jnp.ones((3,), jnp.int32))
 (Array(8., dtype=float32, weak_type=True),
  Array([ 3, 2, 1, 0, -1], dtype=int32),
  Array([4.14112, 4.14112, 4.14112], dtype=float32, weak_type=True))

另一个典型示例是计算平均值(观察 x.shape[0] 如何自动转换为 JAX 数组)

>>> exp = export.export(jax.jit(
...     lambda x: jnp.sum(x, axis=0) / x.shape[0]))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b, c"), jnp.int32))
>>> exp.call(jnp.arange(12, dtype=jnp.int32).reshape((3, 4)))
Array([4., 5., 6., 7.], dtype=float32)

存在形状多态性时的错误#

大多数 JAX 代码都假设 JAX 数组的形状是整数元组,但使用形状多态性时,某些维度可能是符号表达式。这可能会导致一些错误。例如,我们可能会遇到常见的 JAX 形状检查错误

>>> v, = export.symbolic_shape("v,")
>>> export.export(jax.jit(lambda x, y: x + y))(
...     jax.ShapeDtypeStruct((v,), dtype=np.int32),
...     jax.ShapeDtypeStruct((4,), dtype=np.int32))
Traceback (most recent call last):
TypeError: add got incompatible shapes for broadcasting: (v,), (4,).

>>> export.export(jax.jit(lambda x: jnp.matmul(x, x)))(
...     jax.ShapeDtypeStruct((v, 4), dtype=np.int32))
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got (4,) and (v,).

我们可以通过指定参数具有形状 (v, v) 来修复上面的 matmul 示例。

部分支持符号维度的比较#

在 JAX 内部,存在许多涉及形状的相等和不等比较,例如,用于执行形状检查,甚至用于为某些原语选择实现。比较支持如下

  • 如果两个符号维度在维度变量的所有估值下表示相同的值,则支持相等比较,则相等比较的计算结果为 True,例如,对于 b + b == 2*b;否则,相等比较的计算结果为 False。有关此行为重要后果的讨论,请参阅下文

  • 不等比较始终是相等比较的否定。

  • 部分支持不等式,类似于部分相等的情况。然而,在这种情况下,我们会考虑到维度变量的取值范围是严格的正整数。例如,b >= 1b >= 02 * a + b >= 3True,而 b >= 2a >= ba - b >= 0 则是不确定的,会抛出异常。

如果比较操作无法解析为布尔值,我们会抛出 InconclusiveDimensionOperation 异常。例如:

import jax
>>> export.export(jax.jit(lambda x: 0 if x.shape[0] + 1 >= x.shape[1] else 1))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), dtype=np.int32))  # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'a + 1' >= 'b' is inconclusive.
This error arises for comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a boolean value for all values of the symbolic dimensions involved.

如果您遇到了 InconclusiveDimensionOperation,您可以尝试以下几种策略:

  • 如果您的代码使用了内置的 maxmin,或者 np.maxnp.min,那么您可以将它们替换为 core.max_dimcore.min_dim,这样做可以将不等式比较延迟到编译时,届时形状将变得已知。

  • 尝试使用 core.max_dimcore.min_dim 重写条件语句。例如,可以将 d if d > 0 else 0 写成 core.max_dim(d, 0)

  • 尝试重写代码,使其减少对维度必须为整数的依赖,而更多地依赖于符号维度可以像整数一样用于大多数算术运算的事实。例如,可以将 int(d) + 5 写成 d + 5

  • 指定符号约束,如下所述。

用户指定的符号约束#

默认情况下,JAX 假设所有维度变量的取值都大于等于 1,并尝试从中推导出其他简单的不等式,例如:

  • a + 2 >= 3,

  • a * 2 >= 1,

  • a + b + c >= 3,

  • a // 4 >= 0a**2 >= 1,等等。

如果您更改符号形状规范以添加维度大小的**隐式**约束,则可以避免一些不等式比较失败。例如:

  • 您可以使用 2*b 来表示维度,将其约束为偶数且大于等于 2。

  • 您可以使用 b + 15 来表示维度,将其约束为至少为 16。例如,以下代码在没有 + 15 部分的情况下会失败,因为 JAX 会想要验证切片大小是否最多与轴大小相同。

>>> _ = export.export(jax.jit(lambda x: x[0:16]))(
...    jax.ShapeDtypeStruct(export.symbolic_shape("b + 15"), dtype=np.int32))

这种隐式符号约束用于决定比较,并在编译时进行检查,如下文所述

您还可以指定**显式**符号约束

>>> # Introduce dimension variable with constraints.
>>> a, b = export.symbolic_shape("a, b",
...                              constraints=("a >= b", "b >= 16"))
>>> _ = export.export(jax.jit(lambda x: x[:x.shape[1], :16]))(
...    jax.ShapeDtypeStruct((a, b), dtype=np.int32))

这些约束与隐式约束一起形成合取。您可以指定 >=<=== 约束。目前,JAX 对使用符号约束进行推理的支持有限

  • 您从变量大于等于或小于等于常量的形式的约束中获得最大收益。例如,从 a >= 16b >= 8 的约束中,我们可以推断出 a + 2*b >= 32

  • 当约束涉及更复杂的表达式时,您获得的功率有限。例如,从 a >= b + 8 中,我们可以推断出 a - b >= 8,但不能推断出 a >= 9。我们将来可能会在这一领域有所改进。

  • 相等约束被视为重写规则:每当遇到 == 左侧的符号表达式时,它将被重写为右侧的表达式。例如,floordiv(a, b) == c 的工作原理是将所有出现的 floordiv(a, b) 替换为 c。相等约束的左侧顶级不能包含加法或减法。有效的左侧示例包括 a * b,或 4 * a,或 floordiv(a + c, b)

>>> # Introduce dimension variable with equality constraints.
>>> a, b, c, d = export.symbolic_shape("a, b, c, d",
...                                    constraints=("a * b == c + d",))
>>> 2 * b * a
2*d + 2*c

>>> a * b * b
b*d + b*c

符号约束还可以帮助解决 JAX 推理机制中的限制。例如,在下面的代码中,JAX 将尝试证明切片大小 x.shape[0] % 3(即符号表达式 mod(b, 3))小于等于轴大小 b。对于 b 的所有严格正值,这恰好为真,但这不是 JAX 的符号比较规则可以证明的。因此,以下代码会引发错误

from jax import lax
>>> b, = export.symbolic_shape("b")
>>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3)
>>> export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((b,), dtype=np.int32))  # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'b' >= 'mod(b, 3)' is inconclusive.
This error arises for comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a boolean value for all values of the symbolic dimensions involved.

这里的一个选择是将代码限制为仅在轴大小为 3 的倍数的情况下工作(通过将形状中的 b 替换为 3*b)。然后,JAX 将能够将模运算 mod(3*b, 3) 简化为 0。另一个选择是添加一个符号约束,其约束条件恰好是 JAX 尝试证明的无法确定的不等式

>>> b, = export.symbolic_shape("b",
...                            constraints=["b >= mod(b, 3)"])
>>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3)
>>> _ = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((b,), dtype=np.int32))

与隐式约束一样,显式符号约束在编译时进行检查,使用与下文所述相同的机制。

符号维度作用域#

符号约束存储在 αn jax.export.SymbolicScope 对象中,该对象在每次调用 jax.export.symbolic_shapes() 时隐式创建。您必须小心,不要混合使用来自不同作用域的符号表达式。例如,以下代码将失败,因为 a1a2 使用不同的作用域(由不同的 jax.export.symbolic_shape() 调用创建)

>>> a1, = export.symbolic_shape("a,")
>>> a2, = export.symbolic_shape("a,", constraints=("a >= 8",))

>>> a1 + a2  
Traceback (most recent call last):
ValueError: Invalid mixing of symbolic scopes for linear combination.
Expected  scope 4776451856 created at <doctest shape_poly.md[31]>:1:6 (<module>)
and found for 'a' (unknown) scope 4776979920 created at <doctest shape_poly.md[32]>:1:6 (<module>) with constraints:
  a >= 8

来自单个 jax.export.symbolic_shape() 调用的符号表达式共享一个作用域,并且可以在算术运算中混合使用。结果也将共享相同的范围。

您可以重用作用域

>>> a, = export.symbolic_shape("a,", constraints=("a >= 8",))
>>> b, = export.symbolic_shape("b,", scope=a.scope)  # Reuse the scope of `a`

>>> a + b  # Allowed
b + a

您还可以显式创建作用域

>>> my_scope = export.SymbolicScope()
>>> c, = export.symbolic_shape("c", scope=my_scope)
>>> d, = export.symbolic_shape("d", scope=my_scope)
>>> c + d  # Allowed
d + c

JAX 跟踪使用部分由形状键控的缓存,如果使用不同的作用域,则打印相同的符号形状将被视为不同。

相等比较的注意事项#

对于 b + 1 == bb == 0,相等比较会返回 False(在这种情况下,可以确定对于维度变量的所有值,维度都是不同的),但对于 b == 1a == b 也会返回 False。这是不合理的,我们应该抛出 core.InconclusiveDimensionOperation 异常,因为在某些估值下,结果应该是 True,而在其他估值下,结果应该是 False。我们选择使相等性成为完全的,从而允许不合理性,因为否则,在哈希维度表达式或包含它们的对象的(形状、core.AbstractValuecore.Jaxpr)时,我们可能会在存在哈希冲突的情况下得到虚假的错误。除了哈希错误外,部分相等语义还会导致以下表达式 b == a or b == bb in [a, b] 出现错误,即使当我们更改比较的顺序时,也会避免该错误。

即使这样处理相等性,if x.shape[0] != 1: raise NiceErrorMessage 形式的代码也是合理的,但 if x.shape[0] != 1: return 1 形式的代码是不合理的。

维度变量必须可以从输入形状中求解#

目前,当调用导出的对象时,传递维度变量的值的唯一方法是通过数组参数的形状间接传递。例如,可以在调用站点从类型为 f32[b] 的第一个参数的形状推断出 b 的值。这适用于大多数用例,并且它反映了 JIT 函数的调用约定。

有时,您可能希望导出一个由整数值参数化的函数,该整数值确定程序中的某些形状。例如,我们可能希望导出下面定义的函数 my_top_k,该函数由 k 的值参数化,该值决定了结果的形状。以下尝试将导致错误,因为无法从输入 x: i32[4, 10] 的形状推导出维度变量 k

>>> def my_top_k(k, x):  # x: i32[4, 10], k <= 10
...   return lax.top_k(x, k)[0]  # : i32[4, 3]
>>> x = np.arange(40, dtype=np.int32).reshape((4, 10))

>>> # Export with static `k=3`. Since `k` appears in shapes it must be in `static_argnums`.
>>> exp_static_k = export.export(jax.jit(my_top_k, static_argnums=0))(3, x)
>>> exp_static_k.in_avals[0]
ShapedArray(int32[4,10])

>>> exp_static_k.out_avals[0]
ShapedArray(int32[4,3])

>>> # When calling the exported function we pass only the non-static arguments
>>> exp_static_k.call(x)
Array([[ 9,  8,  7],
       [19, 18, 17],
       [29, 28, 27],
       [39, 38, 37]], dtype=int32)

>>> # Now attempt to export with symbolic `k` so that we choose `k` after export.
>>> k, = export.symbolic_shape("k", constraints=["k <= 10"])
>>> export.export(jax.jit(my_top_k, static_argnums=0))(k, x)  
Traceback (most recent call last):
UnexpectedDimVar: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments

将来,除了通过输入形状隐式传递外,我们可能会添加额外的机制来传递维度变量的值。同时,上述用例的解决方法是将函数参数 k 替换为形状为 (0, k) 的数组,以便可以从数组的输入形状推导出 k。第一个维度为 0,以确保整个数组为空,并且在调用导出的函数时不会产生性能损失。

>>> def my_top_k_with_dimensions(dimensions, x):  # dimensions: i32[0, k], x: i32[4, 10]
...   return my_top_k(dimensions.shape[1], x)
>>> exp = export.export(jax.jit(my_top_k_with_dimensions))(
...     jax.ShapeDtypeStruct((0, k), dtype=np.int32),
...     x)
>>> exp.in_avals
(ShapedArray(int32[0,k]), ShapedArray(int32[4,10]))

>>> exp.out_avals[0]
ShapedArray(int32[4,k])

>>> # When we invoke `exp` we must construct and pass an array of shape (0, k)
>>> exp.call(np.zeros((0, 3), dtype=np.int32), x)
Array([[ 9,  8,  7],
       [19, 18, 17],
       [29, 28, 27],
       [39, 38, 37]], dtype=int32)

当某些维度变量出现在输入形状中,但在 JAX 当前无法求解的非线性表达式中时,您也可能会收到错误。

>>> a, = export.symbolic_shape("a")
>>> export.export(jax.jit(lambda x: x.shape[0]))(
...    jax.ShapeDtypeStruct((a * a,), dtype=np.int32))  
Traceback (most recent call last):
ValueError: Cannot solve for values of dimension variables {'a'}.
We can only solve linear uni-variate constraints.
Using the following polymorphic shapes specifications: args[0].shape = (a^2,).
Unprocessed specifications: 'a^2' for dimension size args[0].shape[0].

形状断言错误#

JAX 假设维度变量的范围是严格正整数,并且当为具体的输入形状编译代码时会检查此假设。

例如,给定符号输入形状 (b, b, 2*d),当使用实际参数 arg 调用时,JAX 将生成代码以检查以下断言

  • arg.shape[0] >= 1

  • arg.shape[1] == arg.shape[0]

  • arg.shape[2] % 2 == 0

  • arg.shape[2] // 2 >= 1

例如,这是当我们在形状为 (3, 3, 5) 的参数上调用导出时得到的错误

>>> def f(x):  # x: f32[b, b, 2*d]
...   return x
>>> exp = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b, b, 2*d"), dtype=np.int32))   
>>> exp.call(np.ones((3, 3, 5), dtype=np.int32))  
Traceback (most recent call last):
ValueError: Input shapes do not match the polymorphic shapes specification.
Division had remainder 1 when computing the value of 'd'.
Using the following polymorphic shapes specifications:
  args[0].shape = (b, b, 2*d).
Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3), .
Please see https://jax.ac.cn/en/latest/export/shape_poly.html#shape-assertion-errors for more details.

这些错误在编译之前的预处理步骤中出现。

调试#

首先,请参阅 调试 文档。此外,您可以调试形状细化,它是在编译时为具有维度变量或多平台支持的模块调用的。

如果在形状细化期间出现错误,您可以设置 JAX_DUMP_IR_TO 环境变量,以查看形状细化之前的 HLO 模块的转储(名为 ..._before_refine_polymorphic_shapes.mlir)。此模块应已具有静态输入形状。

要启用形状细化的所有阶段的日志记录,您可以在 OSS 中设置环境变量 TF_CPP_VMODULE=refine_polymorphic_shapes=3(在 Google 内部,您传递 --vmodule=refine_polymorphic_shapes=3

# Log from python
JAX_DUMP_IR_TO=/tmp/export.dumps/ TF_CPP_VMODULE=refine_polymorphic_shapes=3 python tests/shape_poly_test.py ShapePolyTest.test_simple_unary -v=3