错误#

此页面列出了您在使用 JAX 时可能遇到的部分错误,以及如何修复它们的典型示例。

class jax.errors.ConcretizationTypeError(tracer, context='')#

当在需要具体值的上下文中使用 JAX Tracer 对象时,会发生此错误(有关 Tracer 的更多信息,请参阅 不同类型的 JAX 值)。在某些情况下,可以通过将有问题的 value 标记为静态轻松修复;在其他情况下,这可能表明您的程序正在执行 JAX 的 JIT 编译模型不支持的操作。

示例

跟踪值,需要静态值

此错误的一个常见原因是使用跟踪值,但需要静态值。例如

>>> from functools import partial
>>> from jax import jit
>>> import jax.numpy as jnp
>>> @jit
... def func(x, axis):
...   return x.min(axis)
>>> func(jnp.arange(4), 0)  
Traceback (most recent call last):
    ...
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected: axis argument to jnp.min().

这通常可以通过将有问题的参数标记为静态来修复

>>> @partial(jit, static_argnums=1)
... def func(x, axis):
...   return x.min(axis)

>>> func(jnp.arange(4), 0)
Array(0, dtype=int32)
形状取决于追踪的值

当您 JIT 编译的计算中的形状取决于追踪量中的值时,也可能会出现此错误。例如

>>> @jit
... def func(x):
...     return jnp.where(x < 0)

>>> func(jnp.arange(4))  
Traceback (most recent call last):
    ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
The error arose in jnp.nonzero.

这是一个与 JAX 的 JIT 编译模型不兼容的操作示例,该模型要求数组大小在编译时已知。这里,返回数组的大小取决于 x 的内容,此类代码无法 JIT 编译。

在许多情况下,可以通过修改函数中使用的逻辑来解决此问题;例如,这里有一段代码存在类似的问题

>>> @jit
... def func(x):
...     indices = jnp.where(x > 1)
...     return x[indices].sum()

>>> func(jnp.arange(4))  
Traceback (most recent call last):
    ...
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected: The error arose in jnp.nonzero.

这里是如何以避免创建动态大小的索引数组的方式表达相同操作

>>> @jit
... def func(x):
...   return jnp.where(x > 1, x, 0).sum()

>>> func(jnp.arange(4))
Array(5, dtype=int32)

要了解有关追踪器与普通值、具体值与抽象值之间的更多细微差别,您可能需要阅读 不同类型的 JAX 值.

参数:
  • tracer (core.Tracer)

  • context (str)

class jax.errors.KeyReuseError(message)#

当以不安全的方式重复使用 PRNG 密钥时,会出现此错误。仅当 jax_debug_key_reuse 设置为 True 时才会检查密钥重复使用。

这是一个导致此类错误的简单代码示例

>>> with jax.debug_key_reuse(True):  
...   key = jax.random.key(0)
...   value = jax.random.uniform(key)
...   new_value = jax.random.uniform(key)
...
---------------------------------------------------------------------------
KeyReuseError                             Traceback (most recent call last)
...
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0

这种密钥重复使用是有问题的,因为 JAX PRNG 是无状态的,并且必须手动拆分密钥;有关此内容的更多信息,请参阅 Sharp Bits: Random Numbers.

参数:

message (str)

class jax.errors.NonConcreteBooleanIndexError(tracer)#

当程序尝试在追踪的索引操作中使用非具体布尔索引时,会出现此错误。在 JIT 编译下,JAX 数组必须具有静态形状(即在编译时已知的形状),因此必须小心使用布尔掩码。某些通过布尔掩码实现的逻辑在 jax.jit() 函数中根本不可能;在其他情况下,可以以 JIT 兼容的方式重新表达该逻辑,通常使用 where() 的三参数版本。

以下是一些出现此错误的情况示例。

通过布尔掩码构建数组

这最常见于尝试在 JIT 上下文中通过布尔掩码创建数组时。例如

>>> import jax
>>> import jax.numpy as jnp

>>> @jax.jit
... def positive_values(x):
...   return x[x > 0]

>>> positive_values(jnp.arange(-5, 5))  
Traceback (most recent call last):
    ...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])

此函数尝试返回输入数组中的所有正值;除非 x 被标记为静态,否则无法在编译时确定此返回数组的大小,因此无法在 JIT 编译下执行此类操作。

可重新表达的布尔逻辑

虽然不支持直接创建动态大小的数组,但在许多情况下,可以将计算的逻辑重新表达为 JIT 兼容的操作。例如,这里还有另一个由于相同原因在 JIT 下失败的函数

>>> @jax.jit
... def sum_of_positive(x):
...   return x[x > 0].sum()

>>> sum_of_positive(jnp.arange(-5, 5))  
Traceback (most recent call last):
    ...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])

但是,在这种情况下,有问题的数组只是一个中间值,我们可以用 JIT 兼容的三参数版本的 jax.numpy.where() 来表达相同的逻辑

>>> @jax.jit
... def sum_of_positive(x):
...   return jnp.where(x > 0, x, 0).sum()

>>> sum_of_positive(jnp.arange(-5, 5))
Array(10, dtype=int32)

用三参数 where() 替换布尔掩码是解决此类问题的常见方法。

对 JAX 数组进行布尔索引

另一个经常出现此错误的情况是使用布尔索引,例如使用 .at[...].set(...)。这里有一个简单的示例

>>> @jax.jit
... def manual_clip(x):
...   return x.at[x < 0].set(0)

>>> manual_clip(jnp.arange(-2, 2))  
Traceback (most recent call last):
    ...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4])

此函数尝试将小于零的值设置为标量填充值。如上所述,这可以通过用 where() 来重新表达逻辑来解决

>>> @jax.jit
... def manual_clip(x):
...   return jnp.where(x < 0, 0, x)

>>> manual_clip(jnp.arange(-2, 2))
Array([0, 0, 0, 1], dtype=int32)
参数:

tracer (core.Tracer)

class jax.errors.TracerArrayConversionError(tracer)#

当程序尝试将 JAX Tracer 对象转换为标准 NumPy 数组时,会出现此错误(有关 Tracer 的更多信息,请参阅 不同类型的 JAX 值)。它通常发生在以下几种情况之一。

在 JAX 变换中使用非 JAX 函数

如果您尝试在 JAX 变换(jit()grad()jax.vmap() 等)中使用非 JAX 库(如 numpyscipy),则可能会出现此错误。例如

>>> from jax import jit
>>> import numpy as np

>>> @jit
... def func(x):
...   return np.sin(x)

>>> func(np.arange(4))  
Traceback (most recent call last):
    ...
TracerArrayConversionError: The numpy.ndarray conversion method
__array__() was called on traced array with shape int32[4]

在这种情况下,您可以通过使用 jax.numpy.sin() 代替 numpy.sin() 来修复此问题

>>> import jax.numpy as jnp
>>> @jit
... def func(x):
...   return jnp.sin(x)

>>> func(jnp.arange(4))
Array([0.        , 0.84147096, 0.9092974 , 0.14112   ], dtype=float32)

另请参阅 External Callbacks,了解有关从变换后的 JAX 代码回调到主机端计算的选项。

用追踪器对 numpy 数组进行索引

如果此错误出现在涉及数组索引的行上,则可能是因为被索引的数组 x 是一个标准的 numpy.ndarray,而索引 idx 是追踪的 JAX 数组。例如

>>> x = np.arange(10)

>>> @jit
... def func(i):
...   return x[i]

>>> func(0)  
Traceback (most recent call last):
    ...
TracerArrayConversionError: The numpy.ndarray conversion method
__array__() was called on traced array with shape int32[0]

根据上下文,您可以通过将 numpy 数组转换为 JAX 数组来修复此问题

>>> @jit
... def func(i):
...   return jnp.asarray(x)[i]

>>> func(0)
Array(0, dtype=int32)

或通过将索引声明为静态参数

>>> from functools import partial
>>> @partial(jit, static_argnums=(0,))
... def func(i):
...   return x[i]

>>> func(0)
Array(0, dtype=int32)

要了解有关追踪器与普通值、具体值与抽象值之间的更多细微差别,您可能需要阅读 不同类型的 JAX 值.

参数:

tracer (core.Tracer)

class jax.errors.TracerBoolConversionError(tracer)#

当在需要布尔值的上下文中使用 JAX 中的追踪值时,会出现此错误(有关 Tracer 的更多信息,请参阅 不同类型的 JAX 值)。

布尔转换可能是显式的(例如 bool(x))或隐式的,通过使用控制流(例如 if x > 0while x)、使用 Python 布尔运算符(例如 z = x and yz = x or yz = not x)或使用它们的函数(例如 z = max(x, y)z = min(x, y) 等)。

在某些情况下,可以通过将追踪值标记为静态来轻松修复此问题;在其他情况下,这可能表明您的程序正在执行 JAX 的 JIT 编译模型不支持的操作。

示例

在控制流中使用追踪值

这经常出现的一种情况是,当追踪值在 Python 控制流中使用时。例如

>>> from jax import jit
>>> import jax.numpy as jnp
>>> @jit
... def func(x, y):
...   return x if x.sum() < y.sum() else y

>>> func(jnp.ones(4), jnp.zeros(4))  
Traceback (most recent call last):
    ...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer [...]

我们可以将两个输入 xy 标记为静态,但这会违背在这里使用 jax.jit() 的目的。另一个选择是根据三项式 jax.numpy.where() 重新表达 if 语句

>>> @jit
... def func(x, y):
...   return jnp.where(x.sum() < y.sum(), x, y)

>>> func(jnp.ones(4), jnp.zeros(4))
Array([0., 0., 0., 0.], dtype=float32)

对于包括循环在内的更复杂的控制流,请参阅 控制流运算符.

对追踪值进行控制流

此错误的另一个常见原因是,如果您不小心追踪了一个布尔标志。例如

>>> @jit
... def func(x, normalize=True):
...   if normalize:
...     return x / x.sum()
...   return x

>>> func(jnp.arange(5), True)  
Traceback (most recent call last):
    ...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...

这里,因为标志 normalize 被追踪,所以它不能在 Python 控制流中使用。在这种情况下,最好的解决方案可能是将此值标记为静态

>>> from functools import partial
>>> @partial(jit, static_argnames=['normalize'])
... def func(x, normalize=True):
...   if normalize:
...     return x / x.sum()
...   return x

>>> func(jnp.arange(5), True)
Array([0. , 0.1, 0.2, 0.3, 0.4], dtype=float32)

有关 static_argnums 的更多信息,请参阅 jax.jit() 的文档。

使用非 JAX 感知函数

导致此错误的另一个常见原因是在 JAX 代码中使用非 JAX 感知函数。例如

>>> @jit
... def func(x):
...   return min(x, 0)
>>> func(2)  
Traceback (most recent call last):
    ...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...

在这种情况下,错误发生是因为 Python 的内置 min 函数与 JAX 变换不兼容。这可以通过用 jnp.minumum 替换它来修复

>>> @jit
... def func(x):
...   return jnp.minimum(x, 0)
>>> print(func(2))
0

要了解有关追踪器与普通值、具体值与抽象值之间的更多细微差别,您可能需要阅读 不同类型的 JAX 值.

参数:

tracer (core.Tracer)

class jax.errors.TracerIntegerConversionError(tracer)#

当在需要 Python 整数的上下文中使用 JAX Tracer 对象时,可能会出现此错误(有关 Tracer 的更多信息,请参阅 不同类型的 JAX 值)。它通常发生在以下几种情况中。

传递追踪器代替整数

如果您尝试将追踪值传递给需要静态整数参数的函数,则可能会出现此错误;例如

>>> from jax import jit
>>> import numpy as np

>>> @jit
... def func(x, axis):
...   return np.split(x, 2, axis)

>>> func(np.arange(4), 0)  
Traceback (most recent call last):
    ...
TracerIntegerConversionError: The __index__() method was called on
traced array with shape int32[0]

当这种情况发生时,解决方案通常是将有问题的参数标记为静态。

>>> from functools import partial
>>> @partial(jit, static_argnums=1)
... def func(x, axis):
...   return np.split(x, 2, axis)

>>> func(np.arange(10), 0)
[Array([0, 1, 2, 3, 4], dtype=int32),
 Array([5, 6, 7, 8, 9], dtype=int32)]

另一种方法是将转换应用于封装要保护的参数的闭包,可以通过以下手动方式或使用 functools.partial()

>>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4))
[Array([0, 1], dtype=int32), Array([2, 3], dtype=int32)]

请注意,每次调用都会创建一个新的闭包,这会破坏编译缓存机制,这就是为什么首选 static_argnums 的原因。

用 Tracer 索引列表

如果您尝试用追踪的量来索引 Python 列表,则可能会出现此错误。例如

>>> import jax.numpy as jnp
>>> from jax import jit

>>> L = [1, 2, 3]

>>> @jit
... def func(i):
...   return L[i]

>>> func(0)  
Traceback (most recent call last):
    ...
TracerIntegerConversionError: The __index__() method was called on
traced array with shape int32[0]

根据上下文,您通常可以通过将列表转换为 JAX 数组来解决此问题

>>> @jit
... def func(i):
...   return jnp.array(L)[i]

>>> func(0)
Array(1, dtype=int32)

或通过将索引声明为静态参数

>>> from functools import partial
>>> @partial(jit, static_argnums=0)
... def func(i):
...   return L[i]

>>> func(0)
Array(1, dtype=int32, weak_type=True)

要了解有关追踪器与普通值、具体值与抽象值之间的更多细微差别,您可能需要阅读 不同类型的 JAX 值.

参数:

tracer (core.Tracer)

class jax.errors.UnexpectedTracerError(msg)#

当您使用从函数中泄漏的 JAX 值时,就会发生此错误。泄漏值意味着什么?如果您对存储在 f 之外的某个范围内的函数 f 应用 JAX 转换,并且该转换引用了中间值,则该值被视为已泄漏。泄漏值是一种副作用。(有关避免副作用的更多信息,请参阅 纯函数

当您随后在另一个操作中使用泄漏的值时,JAX 会检测到泄漏,此时它会引发 UnexpectedTracerError。要解决此问题,请避免副作用:如果函数计算了外部范围所需的值,请从转换后的函数中显式返回该值。

具体来说,Tracer 是 JAX 在转换期间(例如,在 jit()pmap()vmap() 等中)函数中间值的内部表示。在转换之外遇到 Tracer 意味着发生了泄漏。

泄漏值的生命周期

考虑以下将值泄漏到外部范围的转换函数示例

>>> from jax import jit
>>> import jax.numpy as jnp

>>> outs = []
>>> @jit                   # 1
... def side_effecting(x):
...   y = x + 1            # 3
...   outs.append(y)       # 4

>>> x = 1
>>> side_effecting(x)      # 2
>>> outs[0] + 1            # 5  
Traceback (most recent call last):
    ...
UnexpectedTracerError: Encountered an unexpected tracer.

在此示例中,我们将 Traced 值从内部转换范围泄漏到外部范围。当使用泄漏的值时,我们会收到 UnexpectedTracerError,而不是在值泄漏时。

此示例还演示了泄漏值的生命周期

  1. 函数被转换(在本例中,通过 jit()

  2. 转换后的函数被调用(启动函数的抽象跟踪并将 x 转换为 Tracer

  3. 创建中间值 y,它将在以后泄漏(跟踪函数的中间值也是 Tracer

  4. 值被泄漏(附加到外部范围的列表中,通过侧通道逃离函数)

  5. 使用泄漏的值,并引发 UnexpectedTracerError。

UnexpectedTracerError 消息试图通过包含每个阶段的信息来指向代码中的这些位置。分别

  1. 转换后的函数名称 (side_effecting) 以及启动跟踪的转换 jit()

  2. 重建的堆栈跟踪,其中创建了泄漏的 Tracer,其中包括调用转换后的函数的位置。(When the Tracer was created, the final 5 stack frames were...)。

  3. 从重建的堆栈跟踪中,创建泄漏的 Tracer 的代码行。

  4. 错误消息中没有包含泄漏位置,因为很难确定!JAX 只能告诉您泄漏的值是什么样的(它具有什么形状以及在何处创建)以及它泄漏的边界(转换的名称和转换后的函数的名称)。

  5. 当前错误的堆栈跟踪指向使用值的位置。

可以通过从转换后的函数中返回该值来修复错误

>>> from jax import jit
>>> import jax.numpy as jnp

>>> outs = []
>>> @jit
... def not_side_effecting(x):
...   y = x+1
...   return y

>>> x = 1
>>> y = not_side_effecting(x)
>>> outs.append(y)
>>> outs[0] + 1  # all good! no longer a leaked value.
Array(3, dtype=int32, weak_type=True)
泄漏检查器

如上文第 2 点和第 3 点所述,JAX 显示了重建的堆栈跟踪,该跟踪指向创建泄漏的值的位置。这是因为 JAX 仅在使用泄漏的值时才引发错误,而不是在值泄漏时。这不是引发此错误的最有用位置,因为您需要知道 Tracer 泄漏的位置才能修复错误。

为了使该位置更容易追踪,您可以使用泄漏检查器。启用泄漏检查器后,一旦 Tracer 泄漏,就会引发错误。(更准确地说,它将在从其泄漏了 Tracer 的转换函数返回时引发错误)

要启用泄漏检查器,您可以使用 JAX_CHECK_TRACER_LEAKS 环境变量或 with jax.checking_leaks() 上下文管理器。

注意

请注意,此工具处于实验阶段,可能会报告误报。它通过禁用一些 JAX 缓存来工作,因此它会对性能产生负面影响,并且仅应在调试时使用。

示例用法

>>> from jax import jit
>>> import jax.numpy as jnp

>>> outs = []
>>> @jit
... def side_effecting(x):
...   y = x+1
...   outs.append(y)

>>> x = 1
>>> with jax.checking_leaks():
...   y = side_effecting(x)  
Traceback (most recent call last):
    ...
Exception: Leaked Trace
参数:

msg (str)