错误#
此页面列出了一些在使用 JAX 时可能遇到的错误,以及如何修复它们的代表性示例。
- class jax.errors.ConcretizationTypeError(tracer, context='')#
当 JAX Tracer 对象在需要具体值的上下文中使用时,会发生此错误(有关 Tracer 是什么,请参阅 不同类型的 JAX 值)。在某些情况下,可以通过将有问题的标记为静态来轻松修复;在其他情况下,它可能表明你的程序正在执行 JAX 的 JIT 编译模型不直接支持的操作。
示例
- 在需要静态值的地方使用了跟踪值
导致此错误的一个常见原因是,在需要静态值的地方使用了跟踪值。例如
>>> from functools import partial >>> from jax import jit >>> import jax.numpy as jnp >>> @jit ... def func(x, axis): ... return x.min(axis)
>>> func(jnp.arange(4), 0) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: axis argument to jnp.min().
这通常可以通过将有问题的参数标记为静态来修复
>>> @partial(jit, static_argnums=1) ... def func(x, axis): ... return x.min(axis) >>> func(jnp.arange(4), 0) Array(0, dtype=int32)
- 形状取决于跟踪值
当 JIT 编译的计算中的形状取决于跟踪量中的值时,也可能会出现此类错误。例如
>>> @jit ... def func(x): ... return jnp.where(x < 0) >>> func(jnp.arange(4)) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: The error arose in jnp.nonzero.
这是一个与 JAX 的 JIT 编译模型不兼容的操作示例,该模型要求数组大小在编译时已知。这里,返回的数组的大小取决于 x 的内容,这样的代码不能进行 JIT 编译。
在许多情况下,可以通过修改函数中使用的逻辑来解决这个问题;例如,这里有一段代码存在类似的问题
>>> @jit ... def func(x): ... indices = jnp.where(x > 1) ... return x[indices].sum() >>> func(jnp.arange(4)) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: The error arose in jnp.nonzero.
下面是如何用一种避免创建动态大小索引数组的方式表达相同操作的方法
>>> @jit ... def func(x): ... return jnp.where(x > 1, x, 0).sum() >>> func(jnp.arange(4)) Array(5, dtype=int32)
要了解更多关于跟踪器与常规值以及具体值与抽象值之间的微妙之处,您可以阅读不同类型的 JAX 值。
- 参数:
tracer (core.Tracer)
context (str)
- class jax.errors.KeyReuseError(message)#
当 PRNG 密钥以不安全的方式重复使用时,会发生此错误。仅当 jax_debug_key_reuse 设置为 True 时,才会检查密钥重用。
这是一个导致此类错误的简单代码示例
>>> with jax.debug_key_reuse(True): ... key = jax.random.key(0) ... value = jax.random.uniform(key) ... new_value = jax.random.uniform(key) ... --------------------------------------------------------------------------- KeyReuseError Traceback (most recent call last) ... KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
这种密钥重用是有问题的,因为 JAX PRNG 是无状态的,并且密钥必须手动拆分;有关此方面的更多信息,请参阅伪随机数教程。
- 参数:
message (str)
- jax.errors.JaxRuntimeError#
别名
XlaRuntimeError
- class jax.errors.NonConcreteBooleanIndexError(tracer)#
当程序尝试在跟踪的索引操作中使用非具体的布尔索引时,会发生此错误。在 JIT 编译下,JAX 数组必须具有静态形状(即在编译时已知的形状),因此必须谨慎使用布尔掩码。通过布尔掩码实现的一些逻辑在
jax.jit()
函数中是根本不可能实现的;在其他情况下,逻辑可以以 JIT 兼容的方式重新表达,通常使用where()
的三参数版本。以下是可能发生此错误的一些示例。
- 通过布尔掩码构造数组
当尝试在 JIT 上下文中通过布尔掩码创建数组时,最常出现这种情况。例如
>>> import jax >>> import jax.numpy as jnp >>> @jax.jit ... def positive_values(x): ... return x[x > 0] >>> positive_values(jnp.arange(-5, 5)) Traceback (most recent call last): ... NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
此函数尝试仅返回输入数组中的正值;除非将 x 标记为静态,否则无法在编译时确定此返回数组的大小,因此此类操作无法在 JIT 编译下执行。
- 可重新表达的布尔逻辑
尽管不支持直接创建动态大小的数组,但在许多情况下,可以根据 JIT 兼容的操作来重新表达计算的逻辑。例如,这里有另一个函数由于相同的原因在 JIT 下失败
>>> @jax.jit ... def sum_of_positive(x): ... return x[x > 0].sum() >>> sum_of_positive(jnp.arange(-5, 5)) Traceback (most recent call last): ... NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
但是,在这种情况下,有问题的数组只是一个中间值,我们可以改为用 JIT 兼容的三参数版本的
jax.numpy.where()
来表达相同的逻辑>>> @jax.jit ... def sum_of_positive(x): ... return jnp.where(x > 0, x, 0).sum() >>> sum_of_positive(jnp.arange(-5, 5)) Array(10, dtype=int32)
用三参数
where()
替换布尔掩码是解决此类问题的常见方法。- 将布尔值索引到 JAX 数组中
经常出现此错误的另一种情况是使用布尔索引,例如使用
.at[...].set(...)
。这是一个简单的例子>>> @jax.jit ... def manual_clip(x): ... return x.at[x < 0].set(0) >>> manual_clip(jnp.arange(-2, 2)) Traceback (most recent call last): ... NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4])
此函数尝试将小于零的值设置为标量填充值。如上所述,可以通过用
where()
重新表达逻辑来解决此问题>>> @jax.jit ... def manual_clip(x): ... return jnp.where(x < 0, 0, x) >>> manual_clip(jnp.arange(-2, 2)) Array([0, 0, 0, 1], dtype=int32)
- 参数:
tracer (core.Tracer)
- class jax.errors.TracerArrayConversionError(tracer)#
当程序尝试将 JAX Tracer 对象转换为标准 NumPy 数组时,会发生此错误(有关 Tracer 是什么,请参阅不同类型的 JAX 值)。它通常在以下几种情况之一中发生。
- 在 JAX 转换中使用非 JAX 函数
如果您尝试在 JAX 转换(
jit()
、grad()
、jax.vmap()
等)中使用非 JAX 库(如numpy
或scipy
)时,可能会发生此错误。例如>>> from jax import jit >>> import numpy as np >>> @jit ... def func(x): ... return np.sin(x) >>> func(np.arange(4)) Traceback (most recent call last): ... TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[4]
在这种情况下,您可以通过使用
jax.numpy.sin()
代替numpy.sin()
来解决此问题>>> import jax.numpy as jnp >>> @jit ... def func(x): ... return jnp.sin(x) >>> func(jnp.arange(4)) Array([0. , 0.84147096, 0.9092974 , 0.14112 ], dtype=float32)
另请参阅外部回调,了解从转换后的 JAX 代码回调到主机端计算的选项。
- 使用跟踪器索引 numpy 数组
如果此错误出现在涉及数组索引的行上,则可能是被索引的数组
x
是标准 numpy.ndarray,而索引idx
是跟踪的 JAX 数组。例如>>> x = np.arange(10) >>> @jit ... def func(i): ... return x[i] >>> func(0) Traceback (most recent call last): ... TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[0]
根据上下文,您可以通过将 numpy 数组转换为 JAX 数组来解决此问题
>>> @jit ... def func(i): ... return jnp.asarray(x)[i] >>> func(0) Array(0, dtype=int32)
或通过将索引声明为静态参数
>>> from functools import partial >>> @partial(jit, static_argnums=(0,)) ... def func(i): ... return x[i] >>> func(0) Array(0, dtype=int32)
要了解更多关于跟踪器与常规值以及具体值与抽象值之间的微妙之处,您可以阅读不同类型的 JAX 值。
- 参数:
tracer (core.Tracer)
- class jax.errors.TracerBoolConversionError(tracer)#
当 JAX 中跟踪的值在需要布尔值的上下文中使用时,会发生此错误(有关 Tracer 是什么,请参阅不同类型的 JAX 值)。
布尔转换可以是显式的(例如
bool(x)
)或隐式的,通过控制流的使用(例如if x > 0
或while x
),使用 Python 布尔运算符(例如z = x and y
,z = x or y
,z = not x
)或使用它们的函数(例如z = max(x, y)
,z = min(x, y)
等)。在某些情况下,可以通过将跟踪的值标记为静态来轻松解决此问题;在其他情况下,它可能表明您的程序正在执行 JAX 的 JIT 编译模型不直接支持的操作。
示例
- 在控制流中使用的跟踪值
经常出现这种情况的一种情况是在 Python 控制流中使用跟踪的值。例如
>>> from jax import jit >>> import jax.numpy as jnp >>> @jit ... def func(x, y): ... return x if x.sum() < y.sum() else y >>> func(jnp.ones(4), jnp.zeros(4)) Traceback (most recent call last): ... TracerBoolConversionError: Attempted boolean conversion of JAX Tracer [...]
我们可以将输入
x
和y
都标记为静态,但这会破坏在此处使用jax.jit()
的目的。另一种选择是用三项jax.numpy.where()
重新表达 if 语句>>> @jit ... def func(x, y): ... return jnp.where(x.sum() < y.sum(), x, y) >>> func(jnp.ones(4), jnp.zeros(4)) Array([0., 0., 0., 0.], dtype=float32)
对于包括循环在内的更复杂的控制流,请参阅控制流运算符。
- 对跟踪值进行控制流
导致此错误的另一个常见原因是如果您不小心跟踪了一个布尔标志。例如
>>> @jit ... def func(x, normalize=True): ... if normalize: ... return x / x.sum() ... return x >>> func(jnp.arange(5), True) Traceback (most recent call last): ... TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...
由于此处的标志
normalize
被跟踪,因此不能在 Python 控制流中使用。在这种情况下,最佳解决方案可能是将此值标记为静态>>> from functools import partial >>> @partial(jit, static_argnames=['normalize']) ... def func(x, normalize=True): ... if normalize: ... return x / x.sum() ... return x >>> func(jnp.arange(5), True) Array([0. , 0.1, 0.2, 0.3, 0.4], dtype=float32)
有关
static_argnums
的更多信息,请参阅jax.jit()
的文档。- 使用非 JAX 感知函数
导致此错误的另一个常见原因是,在 JAX 代码中使用非 JAX 感知函数。例如
>>> @jit ... def func(x): ... return min(x, 0)
>>> func(2) Traceback (most recent call last): ... TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...
在这种情况下,发生错误是因为 Python 的内置
min
函数与 JAX 转换不兼容。可以通过将其替换为jnp.minumum
来解决此问题>>> @jit ... def func(x): ... return jnp.minimum(x, 0)
>>> print(func(2)) 0
要了解更多关于跟踪器与常规值以及具体值与抽象值之间的微妙之处,您可以阅读不同类型的 JAX 值。
- 参数:
tracer (core.Tracer)
- class jax.errors.TracerIntegerConversionError(tracer)#
当在需要 Python 整数的上下文中使用 JAX Tracer 对象时,可能会发生此错误(有关 Tracer 是什么,请参阅不同类型的 JAX 值)。它通常在以下几种情况中发生。
- 传递跟踪器代替整数
如果尝试将跟踪值传递给需要静态整数参数的函数,则可能会发生此错误;例如
>>> from jax import jit >>> import numpy as np >>> @jit ... def func(x, axis): ... return np.split(x, 2, axis) >>> func(np.arange(4), 0) Traceback (most recent call last): ... TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[0]
发生这种情况时,通常的解决方法是将有问题的参数标记为静态
>>> from functools import partial >>> @partial(jit, static_argnums=1) ... def func(x, axis): ... return np.split(x, 2, axis) >>> func(np.arange(10), 0) [Array([0, 1, 2, 3, 4], dtype=int32), Array([5, 6, 7, 8, 9], dtype=int32)]
另一种方法是将转换应用于封装要保护的参数的闭包,可以手动执行如下操作,也可以使用
functools.partial()
>>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4)) [Array([0, 1], dtype=int32), Array([2, 3], dtype=int32)]
请注意,每次调用都会创建一个新的闭包,这会破坏编译缓存机制,因此首选 static_argnums。
- 使用 Tracer 索引列表
如果尝试使用跟踪量索引 Python 列表,则可能会发生此错误。例如
>>> import jax.numpy as jnp >>> from jax import jit >>> L = [1, 2, 3] >>> @jit ... def func(i): ... return L[i] >>> func(0) Traceback (most recent call last): ... TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[0]
根据上下文,通常可以通过将列表转换为 JAX 数组来解决此问题
>>> @jit ... def func(i): ... return jnp.array(L)[i] >>> func(0) Array(1, dtype=int32)
或通过将索引声明为静态参数
>>> from functools import partial >>> @partial(jit, static_argnums=0) ... def func(i): ... return L[i] >>> func(0) Array(1, dtype=int32, weak_type=True)
要了解更多关于跟踪器与常规值以及具体值与抽象值之间的微妙之处,您可以阅读不同类型的 JAX 值。
- 参数:
tracer (core.Tracer)
- class jax.errors.UnexpectedTracerError(msg)#
当您使用已从函数中泄漏出来的 JAX 值时,会发生此错误。泄漏值是什么意思?如果您对函数
f
使用 JAX 转换,该转换在f
之外的某个作用域中存储对中间值的引用,则该值被视为已泄漏。泄漏值是一种副作用。(阅读更多关于在 纯函数 中避免副作用的内容)当您稍后在另一个操作中使用泄漏的值时,JAX 会检测到泄漏,此时它会引发
UnexpectedTracerError
。要解决此问题,请避免副作用:如果函数计算外部作用域中需要的值,则从转换后的函数显式返回该值。具体来说,
Tracer
是 JAX 在转换期间(例如在jit()
,pmap()
,vmap()
等中)内部表示函数的中间值的方式。在转换之外遇到Tracer
意味着存在泄漏。- 泄漏值的生命周期
考虑以下示例,该示例显示了将值泄漏到外部作用域的转换后的函数
>>> from jax import jit >>> import jax.numpy as jnp >>> outs = [] >>> @jit # 1 ... def side_effecting(x): ... y = x + 1 # 3 ... outs.append(y) # 4 >>> x = 1 >>> side_effecting(x) # 2 >>> outs[0] + 1 # 5 Traceback (most recent call last): ... UnexpectedTracerError: Encountered an unexpected tracer.
在此示例中,我们将跟踪值从内部转换后的作用域泄漏到外部作用域。当使用泄漏的值时,而不是在泄漏该值时,我们会收到
UnexpectedTracerError
。此示例还演示了泄漏值的生命周期
函数被转换(在这种情况下,由
jit()
)调用转换后的函数(启动函数的抽象跟踪,并将
x
转换为Tracer
)创建中间值
y
,该值稍后将泄漏(跟踪函数的中间值也是Tracer
)该值被泄漏(通过侧通道附加到外部作用域中的列表,从而逃脱函数)
使用泄漏的值,并引发 UnexpectedTracerError。
UnexpectedTracerError 消息会尝试通过包含有关每个阶段的信息来指向代码中的这些位置。分别
转换后的函数的名称 (
side_effecting
) 以及哪个转换启动了跟踪jit()
。重建的泄漏 Tracer 的创建位置的堆栈跟踪,其中包括调用转换后的函数的位置。(
当 Tracer 被创建时,最后的 5 个堆栈帧是...
)。从重建的堆栈跟踪中,泄漏的 Tracer 的创建代码行。
错误消息中不包含泄漏位置,因为它很难确定!JAX 只能告诉您泄漏的值是什么样子(它具有什么形状以及在何处创建)以及它泄漏到的边界(转换的名称和转换后的函数的名称)。
当前错误的堆栈跟踪指向使用该值的位置。
可以通过从转换后的函数中返回该值来解决该错误
>>> from jax import jit >>> import jax.numpy as jnp >>> outs = [] >>> @jit ... def not_side_effecting(x): ... y = x+1 ... return y >>> x = 1 >>> y = not_side_effecting(x) >>> outs.append(y) >>> outs[0] + 1 # all good! no longer a leaked value. Array(3, dtype=int32, weak_type=True)
- 泄漏检查器
如上文第 2 点和第 3 点所述,JAX 显示了一个重建的堆栈跟踪,该跟踪指向泄漏值创建的位置。这是因为 JAX 仅在使用泄漏的值时引发错误,而不是在泄漏该值时引发错误。这不是引发此错误的最有用的位置,因为您需要知道 Tracer 泄漏的位置才能修复该错误。
为了使此位置更容易追踪,您可以使用泄漏检查器。启用泄漏检查器后,一旦泄漏
Tracer
,就会引发错误。(更准确地说,它将在泄漏Tracer
的转换后的函数返回时引发错误)要启用泄漏检查器,您可以使用
JAX_CHECK_TRACER_LEAKS
环境变量或with jax.checking_leaks()
上下文管理器。注意
请注意,此工具是实验性的,可能会报告误报。它通过禁用一些 JAX 缓存来工作,因此会对性能产生负面影响,并且仅应在调试时使用。
用法示例
>>> from jax import jit >>> import jax.numpy as jnp >>> outs = [] >>> @jit ... def side_effecting(x): ... y = x+1 ... outs.append(y) >>> x = 1 >>> with jax.checking_leaks(): ... y = side_effecting(x) Traceback (most recent call last): ... Exception: Leaked Trace
- 参数:
msg (str)