常见问题解答 (FAQ)#

我们在此处收集常见问题的答案。欢迎贡献!

jit 改变了我的函数的行为#

如果你的 Python 函数在使用 jax.jit() 后行为发生改变,可能是因为你的函数使用了全局状态或有副作用。在以下代码中,impure_func 使用了全局变量 y,并且由于 print 而产生了副作用。

y = 0

# @jit   # Different behavior with jit
def impure_func(x):
  print("Inside:", y)
  return x + y

for y in range(3):
  print("Result:", impure_func(y))

不使用 jit 时,输出为

Inside: 0
Result: 0
Inside: 1
Result: 2
Inside: 2
Result: 4

使用 jit 时,输出为

Inside: 0
Result: 0
Result: 1
Result: 2

对于 jax.jit(),该函数会在 Python 解释器中执行一次,此时会发生 Inside 打印,并观察到 y 的第一个值。然后,该函数会被编译并缓存,并使用不同的 x 值执行多次,但使用的是 y 的第一个值。

延伸阅读

jit 会改变输出的精确数值#

有时用户会惊讶地发现,用 jit() 包装一个函数会改变函数的输出。例如:

>>> from jax import jit
>>> import jax.numpy as jnp
>>> def f(x):
...   return jnp.log(jnp.sqrt(x))
>>> x = jnp.pi
>>> print(f(x))
0.572365
>>> print(jit(f)(x))
0.5723649

输出的这种细微差异来自 XLA 编译器中的优化:在编译期间,XLA 有时会重新排列或省略某些操作,以使整体计算更高效。

在本例中,XLA 利用对数的性质,将 log(sqrt(x)) 替换为 0.5 * log(x),这是一个数学上相同的表达式,可以比原始表达式更有效地计算。输出的差异来自于浮点运算只是对实数数学的近似,因此计算相同表达式的不同方式可能会产生细微不同的结果。

在其他情况下,XLA 的优化可能会导致更显着的差异。考虑以下示例:

>>> def f(x):
...   return jnp.log(jnp.exp(x))
>>> x = 100.0
>>> print(f(x))
inf
>>> print(jit(f)(x))
100.0

在非 JIT 编译的逐操作模式下,结果是 inf,因为 jnp.exp(x) 溢出并返回 inf。然而,在 JIT 下,XLA 识别到 logexp 的反函数,并从已编译的函数中删除这些操作,只返回输入。在这种情况下,JIT 编译产生了更精确的实数结果的浮点近似。

不幸的是,XLA 的代数简化完整列表没有详细记录,但如果您熟悉 C++ 并且好奇 XLA 编译器进行了哪些类型的优化,您可以在源代码中查看它们:algebraic_simplifier.cc

jit 修饰的函数编译速度非常慢#

如果你的 jit 修饰的函数第一次调用时需要数十秒(或更长时间)才能运行,但在再次调用时执行速度很快,则表明 JAX 花费了很长时间来跟踪或编译你的代码。

这通常表示调用你的函数会在 JAX 的内部表示中生成大量代码,通常是因为它大量使用了 Python 控制流,例如 for 循环。对于少量的循环迭代,Python 可以接受,但是如果你需要大量的循环迭代,你应该重写你的代码以使用 JAX 的 结构化控制流原语(例如 lax.scan()),或者避免用 jit 包装循环(你仍然可以在循环内部使用 jit 修饰的函数)。

如果你不确定这是否是问题所在,你可以尝试在你的函数上运行 jax.make_jaxpr()。如果输出有数百或数千行长,则可以预期编译速度会很慢。

有时,如何重写代码以避免 Python 循环并不明显,因为你的代码使用了许多具有不同形状的数组。在这种情况下,推荐的解决方案是使用诸如 jax.numpy.where() 之类的函数,以便对具有固定形状的填充数组执行计算。

如果你的函数由于其他原因编译速度慢,请在 GitHub 上打开一个问题。

如何将 jit 与方法一起使用?#

jax.jit() 的大多数示例都涉及修饰独立的 Python 函数,但在类中修饰一个方法会引入一些复杂性。例如,考虑以下简单的类,我们在其中对一个方法使用了标准的 jit() 注释:

>>> import jax.numpy as jnp
>>> from jax import jit

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit  # <---- How to do this correctly?
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y

但是,当你尝试调用此方法时,此方法将导致错误:

>>> c = CustomClass(2, True)
>>> c.calc(3)  
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
  File "<stdin>", line 1, in <module
TypeError: Argument '<CustomClass object at 0x7f7dd4125890>' of type <class 'CustomClass'> is not a valid JAX type.

问题在于该函数的第一个参数是 self,其类型为 CustomClass,而 JAX 不知道如何处理此类型。在这种情况下,我们可以使用三种基本策略,我们将在下面讨论它们。

策略 1:JIT 编译的辅助函数#

最直接的方法是创建一个类外部的辅助函数,该函数可以像通常的方式一样进行 JIT 修饰。例如:

>>> from functools import partial

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   def calc(self, y):
...     return _calc(self.mul, self.x, y)

>>> @partial(jit, static_argnums=0)
... def _calc(mul, x, y):
...   if mul:
...     return x * y
...   return y

结果将按预期工作:

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

这种方法的好处是它简单、明确,并且避免了教 JAX 如何处理 CustomClass 类型对象的需求。但是,你可能希望将所有方法逻辑都放在同一位置。

策略 2:将 self 标记为静态#

另一种常见的模式是使用 static_argnumsself 参数标记为静态。但是,必须小心执行此操作,以避免出现意外结果。你可能会尝试简单地执行此操作:

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   # WARNING: this example is broken, as we'll see below. Don't copy & paste!
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y

如果你调用该方法,它将不再引发错误:

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

但是,这里有一个陷阱:如果你在第一次方法调用后改变了对象,则后续方法调用可能会返回不正确的结果:

>>> c.mul = False
>>> print(c.calc(3))  # Should print 3
6

这是为什么呢?当你将一个对象标记为静态时,它实际上将被用作 JIT 内部编译缓存中的字典键,这意味着其哈希值(即 hash(obj))相等性(即 obj1 == obj2)和对象标识(即 obj1 is obj2)将被假定为具有一致的行为。自定义对象的默认 __hash__ 是其对象 ID,因此 JAX 无法知道一个更改后的对象应该触发重新编译。

你可以通过为你的对象定义适当的 __hash____eq__ 方法来部分解决此问题;例如:

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def __hash__(self):
...     return hash((self.x, self.mul))
...
...   def __eq__(self, other):
...     return (isinstance(other, CustomClass) and
...             (self.x, self.mul) == (other.x, other.mul))

(有关重写 __hash__ 时的要求的更多讨论,请参阅 object.__hash__() 文档)。

只要你从不改变你的对象,这应该可以与 JIT 和其他转换一起正常工作。用作哈希键的对象的更改会导致一些微妙的问题,这就是为什么例如可变的 Python 容器(例如 dictlist)不定义 __hash__,而它们的不可变对应项(例如 tuple)定义 __hash__

如果你的类依赖于就地更改(例如在方法内设置 self.attr = ...),则你的对象实际上不是“静态的”,将其标记为静态可能会导致问题。幸运的是,在这种情况下还有另一种选择。

策略 3:使 CustomClass 成为 PyTree#

正确 JIT 编译类方法的最灵活方法是将该类型注册为自定义 PyTree 对象;请参阅 扩展 pytrees。这使你可以精确指定应将类的哪些组件视为静态,哪些应视为动态。它的外观如下:

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def _tree_flatten(self):
...     children = (self.x,)  # arrays / dynamic values
...     aux_data = {'mul': self.mul}  # static values
...     return (children, aux_data)
...
...   @classmethod
...   def _tree_unflatten(cls, aux_data, children):
...     return cls(*children, **aux_data)

>>> from jax import tree_util
>>> tree_util.register_pytree_node(CustomClass,
...                                CustomClass._tree_flatten,
...                                CustomClass._tree_unflatten)

这当然更复杂,但它解决了与上面使用的更简单方法相关的所有问题:

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

>>> c.mul = False  # mutation is detected
>>> print(c.calc(3))
3

>>> c = CustomClass(jnp.array(2), True)  # non-hashable x is supported
>>> print(c.calc(3))
6

只要你的 tree_flattentree_unflatten 函数正确处理了类中的所有相关属性,你应该可以直接将此类型的对象用作 JIT 编译函数的参数,而无需任何特殊注释。

控制设备上的数据和计算放置#

我们首先来看一下 JAX 中数据和计算放置的原则。

在 JAX 中,计算遵循数据放置。JAX 数组有两个放置属性:1) 数据所在的设备;2) 数据是否提交到设备(有时数据被称为在设备上)。

默认情况下,JAX 数组以未提交状态放置在默认设备上(jax.devices()[0]),默认情况下这是第一个 GPU 或 TPU。如果没有 GPU 或 TPU,jax.devices()[0] 则是 CPU。默认设备可以使用 jax.default_device() 上下文管理器临时覆盖,或者通过设置环境变量 JAX_PLATFORMS 或 absl 标志 --jax_platforms 为“cpu”、“gpu”或“tpu”来为整个进程设置(JAX_PLATFORMS 也可以是平台列表,它决定了哪些平台按优先级顺序可用)。

>>> from jax import numpy as jnp
>>> print(jnp.ones(3).devices())  
{CudaDevice(id=0)}

涉及未提交数据的计算在默认设备上执行,结果在默认设备上也是未提交的。

也可以使用带有 device 参数的 jax.device_put() 将数据显式放置在设备上,在这种情况下,数据将提交到设备。

>>> import jax
>>> from jax import device_put
>>> arr = device_put(1, jax.devices()[2])  
>>> print(arr.devices())  
{CudaDevice(id=2)}

涉及某些已提交输入的计算将在已提交的设备上进行,结果将提交到同一设备。在提交到多个设备的参数上调用操作会引发错误。

您也可以使用不带 device 参数的 jax.device_put()。如果数据已在某个设备上(已提交或未提交),则保持原样。如果数据不在任何设备上(即,它是常规的 Python 或 NumPy 值),则将其以未提交状态放置在默认设备上。

Jitted 函数的行为类似于任何其他基本操作 - 它们将跟随数据,如果对提交到多个设备的数据调用,则会显示错误。

(在 2021 年 3 月的 PR #6002 之前,数组常量的创建存在一些延迟,因此 jax.device_put(jnp.zeros(...), jax.devices()[1]) 或类似的操作实际上会在 jax.devices()[1] 上创建零数组,而不是在默认设备上创建数组然后移动它。但是,为了简化实现,此优化已被删除。)

(从 2020 年 4 月起,jax.jit() 有一个影响设备放置的 device 参数。该参数是实验性的,可能会被删除或更改,不建议使用。)

有关详细示例,我们建议阅读 multi_device_test.py 中的 test_computation_follows_data

基准测试 JAX 代码#

您刚刚将一个棘手的函数从 NumPy/SciPy 移植到 JAX。这实际上加快了速度吗?

在测量使用 JAX 的代码的速度时,请记住与 NumPy 的这些重要区别

  1. JAX 代码是即时 (JIT) 编译的。大多数用 JAX 编写的代码都可以以支持 JIT 编译的方式编写,这可以使其运行快得多(请参阅 To JIT or not to JIT)。要从 JAX 获得最大性能,您应该在最外层的函数调用上应用 jax.jit()

    请记住,第一次运行 JAX 代码时,速度会较慢,因为它正在被编译。即使您在自己的代码中不使用 jit,也是如此,因为 JAX 的内置函数也是 JIT 编译的。

  2. JAX 具有异步调度。这意味着您需要调用 .block_until_ready() 以确保计算实际发生(请参阅 异步调度)。

  3. JAX 默认只使用 32 位数据类型。您可能希望在 NumPy 中显式使用 32 位数据类型,或在 JAX 中启用 64 位数据类型(请参阅 双精度 (64 位))以进行公平比较。

  4. 在 CPU 和加速器之间传输数据需要时间。如果您只想测量评估函数需要多长时间,您可能需要先将数据传输到要运行它的设备上(请参阅 控制设备上的数据和计算放置)。

这是一个示例,说明如何将所有这些技巧组合到一个微基准测试中,以比较 JAX 与 NumPy,利用 IPython 方便的 %time 和 %timeit 魔术命令

import numpy as np
import jax.numpy as jnp
import jax

def f(x):  # function we're benchmarking (works in both NumPy & JAX)
  return x.T @ (x - x.mean(axis=0))

x_np = np.ones((1000, 1000), dtype=np.float32)  # same as JAX default dtype
%timeit f(x_np)  # measure NumPy runtime

%time x_jax = jax.device_put(x_np)  # measure JAX device transfer time
f_jit = jax.jit(f)
%time f_jit(x_jax).block_until_ready()  # measure JAX compilation time
%timeit f_jit(x_jax).block_until_ready()  # measure JAX runtime

Colab 中使用 GPU 运行时,我们看到

  • NumPy 在 CPU 上每次评估需要 16.2 毫秒

  • JAX 需要 1.26 毫秒将 NumPy 数组复制到 GPU 上

  • JAX 需要 193 毫秒来编译函数

  • JAX 在 GPU 上每次评估需要 485 微秒

在这种情况下,我们看到一旦数据传输完成并且函数编译完成,GPU 上的 JAX 对于重复评估来说快了大约 30 倍。

这是一个公平的比较吗?也许吧。最终重要的是运行完整应用程序的性能,这不可避免地包括一些数据传输和编译。此外,我们小心地选择了足够大的数组 (1000x1000) 和足够密集的计算(@ 运算符执行矩阵-矩阵乘法),以摊销 JAX/加速器与 NumPy/CPU 相比增加的开销。例如,如果我们切换此示例以使用 10x10 输入,则 JAX/GPU 的运行速度比 NumPy/CPU 慢 10 倍(100 微秒 vs 10 微秒)。

JAX 比 NumPy 快吗?#

用户经常试图用这样的基准测试来回答的一个问题是 JAX 是否比 NumPy 快;由于这两个包的差异,没有简单的答案。

一般来说

  • NumPy 操作是急切地、同步地执行的,并且仅在 CPU 上执行。

  • JAX 操作可以在急切地执行,也可以在编译后执行(如果在 jit() 内部);它们是异步调度的(请参阅 异步调度);并且它们可以在 CPU、GPU 或 TPU 上执行,每个设备的性能特征都截然不同且不断发展。

这些架构差异使得 NumPy 和 JAX 之间有意义的直接基准比较变得困难。

此外,这些差异导致了软件包之间不同的工程重点:例如,NumPy 在降低单个数组操作的每次调用调度开销方面付出了巨大的努力,因为在 NumPy 的计算模型中,这种开销是无法避免的。另一方面,JAX 有几种方法可以避免调度开销(例如,JIT 编译、异步调度、批处理转换等),因此减少每次调用的开销不是优先事项。

考虑到所有这些,总而言之:如果您正在对 CPU 上的单个数组操作进行微基准测试,那么通常可以预期 NumPy 的性能会优于 JAX,因为它每次操作的调度开销较低。如果您在 GPU 或 TPU 上运行代码,或者正在对 CPU 上更复杂的 JIT 编译操作序列进行基准测试,那么通常可以预期 JAX 的性能会优于 NumPy。

不同类型的 JAX 值#

在转换函数的过程中,JAX 将一些函数参数替换为特殊的跟踪器值。

如果您使用 print 语句,您可能会看到这一点

def func(x):
  print(x)
  return jnp.cos(x)

res = jax.jit(func)(0.)

上面的代码确实返回了正确的值 1.,但它还为 x 的值打印了 Traced<ShapedArray(float32[])>。通常,JAX 会以透明的方式在内部处理这些跟踪器值,例如,在用于实现 jax.numpy 函数的数值 JAX 原语中。这就是为什么 jnp.cos 在上面的示例中起作用的原因。

更准确地说,为 JAX 转换函数的参数引入一个跟踪器值,除了由特殊参数标识的参数,例如 jax.jit()static_argnumsjax.pmap()static_broadcasted_argnums。通常,涉及至少一个跟踪器值的计算会生成一个跟踪器值。除了跟踪器值之外,还有常规的 Python 值:在 JAX 转换之外计算的值,或来自上述某些 JAX 转换的静态参数的值,或仅从其他常规 Python 值计算的值。这些值是在没有 JAX 转换的情况下在任何地方使用的值。

跟踪器值携带一个抽象值,例如,带有关于数组的形状和数据类型信息的 ShapedArray。我们在这里将这些跟踪器称为抽象跟踪器。一些跟踪器,例如为自动微分转换的参数引入的跟踪器,携带 ConcreteArray 抽象值,这些值实际上包括常规的数组数据,并且用于,例如,解析条件。我们在这里将这些跟踪器称为具体跟踪器。从这些具体跟踪器计算出的跟踪器值,可能与常规值结合使用,会产生具体跟踪器。一个具体值要么是一个常规值,要么是一个具体跟踪器。

通常情况下,从跟踪器值计算得到的值本身也是跟踪器值。 只有极少数例外情况,当计算完全可以使用跟踪器携带的抽象值完成时,结果可以是常规值。 例如,获取具有 ShapedArray 抽象值的跟踪器的形状。 另一个例子是当显式将具体跟踪器值转换为常规类型时,例如 int(x)x.astype(float)。 另一种情况是 bool(x),当具体化成为可能时,它会生成一个 Python 布尔值。这种情况尤其突出,因为它在控制流中经常出现。

以下是转换如何引入抽象或具体跟踪器:

  • jax.jit():为所有位置参数引入抽象跟踪器,但由 static_argnums 指定的参数除外,这些参数保持为常规值。

  • jax.pmap():为所有位置参数引入抽象跟踪器,但由 static_broadcasted_argnums 指定的参数除外。

  • jax.vmap()jax.make_jaxpr()xla_computation():为所有位置参数引入抽象跟踪器

  • jax.jvp()jax.grad() 为所有位置参数引入具体跟踪器。例外情况是当这些转换在外层转换内部时,并且实际参数本身是抽象跟踪器;在这种情况下,自动微分转换引入的跟踪器也是抽象跟踪器。

  • 所有高阶控制流原语(lax.cond()lax.while_loop()lax.fori_loop()lax.scan())在处理函数时会引入抽象跟踪器,无论是否正在进行 JAX 转换。

当您的代码只能在常规 Python 值上操作时,例如基于数据的条件控制流的代码,所有这些都是相关的。

def divide(x, y):
  return x / y if y >= 1. else 0.

如果我们想应用 jax.jit(),我们必须确保指定 static_argnums=1 以确保 y 保持为常规值。这是因为布尔表达式 y >= 1.,它需要具体值(常规值或跟踪器)。如果我们显式地写 bool(y >= 1.)int(y)float(y),也会发生同样的情况。

有趣的是,jax.grad(divide)(3., 2.) 可以工作,因为 jax.grad() 使用具体跟踪器,并使用 y 的具体值解析条件。

缓冲区捐赠#

当 JAX 执行计算时,它会使用设备上的缓冲区来存储所有输入和输出。如果您知道某个输入在计算后不需要,并且它与其中一个输出的形状和元素类型匹配,则可以指定您希望将相应的输入缓冲区捐赠以保存输出。这将减少执行所需的内存,减少量等于捐赠缓冲区的大小。

如果您有如下模式,则可以使用缓冲区捐赠:

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state)

您可以将其视为对不可变的 JAX 数组进行内存高效的函数式更新的一种方式。在计算的边界内,XLA 可以为您进行此优化,但在 jit/pmap 边界,您需要向 XLA 保证在调用捐赠函数后不会使用捐赠的输入缓冲区。

您可以通过使用函数 jax.jit()jax.pjit()jax.pmap()donate_argnums 参数来实现。此参数是位置参数列表的索引序列(从 0 开始)。

def add(x, y):
  return x + y

x = jax.device_put(np.ones((2, 3)))
y = jax.device_put(np.ones((2, 3)))
# Execute `add` with donation of the buffer for `y`. The result has
# the same shape and type as `y`, so it will share its buffer.
z = jax.jit(add, donate_argnums=(1,))(x, y)

请注意,当使用关键字参数调用函数时,目前不起作用!以下代码不会捐赠任何缓冲区:

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params=params, state=state)

如果捐赠缓冲区的参数是 pytree,则会捐赠其所有组件的缓冲区。

def add_ones(xs: List[Array]):
  return [x + 1 for x in xs]

xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))]
# Execute `add_ones` with donation of all the buffers for `xs`.
# The outputs have the same shape and type as the elements of `xs`,
# so they will share those buffers.
z = jax.jit(add_ones, donate_argnums=0)(xs)

不允许捐赠随后在计算中使用的缓冲区,并且 JAX 会给出错误,因为在捐赠后 y 的缓冲区已失效。

# Donate the buffer for `y`
z = jax.jit(add, donate_argnums=(1,))(x, y)
w = y + 1  # Reuses `y` whose buffer was donated above
# >> RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer

如果未使用的捐赠缓冲区过多(例如,由于捐赠的缓冲区多于输出可用的缓冲区),您会收到警告。

# Execute `add` with donation of the buffers for both `x` and `y`.
# One of those buffers will be used for the result, but the other will
# not be used.
z = jax.jit(add, donate_argnums=(0, 1))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[2,3]{1,0}

如果没有输出的形状与捐赠匹配,则捐赠也可能未使用。

y = jax.device_put(np.ones((1, 3)))  # `y` has different shape than the output
# Execute `add` with donation of the buffer for `y`.
z = jax.jit(add, donate_argnums=(1,))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[1,3]{1,0}

使用 where 时,梯度包含 NaN#

如果使用 where 定义函数以避免未定义的值,如果不小心,您可能会在反向微分时获得 NaN

def my_log(x):
  return jnp.where(x > 0., jnp.log(x), 0.)

my_log(0.) ==> 0.  # Ok
jax.grad(my_log)(0.)  ==> NaN

简短的解释是,在 grad 计算期间,对应于未定义的 jnp.log(x) 的伴随是 NaN,它会被累积到 jnp.where 的伴随中。编写此类函数的正确方法是确保在部分定义的函数内部有一个 jnp.where,以确保伴随始终是有限的。

def safe_for_grad_log(x):
  return jnp.log(jnp.where(x > 0., x, 1.))

safe_for_grad_log(0.) ==> 0.  # Ok
jax.grad(safe_for_grad_log)(0.)  ==> 0.  # Ok

除了原始的 jnp.where 之外,可能还需要内部的 jnp.where,例如:

def my_log_or_y(x, y):
  """Return log(x) if x > 0 or y"""
  return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)

延伸阅读

为什么基于排序顺序的函数的梯度为零?#

如果您定义一个使用依赖于输入相对顺序的操作(例如 maxgreaterargsort 等)处理输入的函数,那么您可能会惊讶地发现梯度处处为零。 这是一个示例,我们定义 f(x) 为一个阶跃函数,当 x 为负数时返回 0,当 x 为正数时返回 1

import jax
import numpy as np
import jax.numpy as jnp

def f(x):
  return (x > 0).astype(float)

df = jax.vmap(jax.grad(f))

x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0])

print(f"f(x)  = {f(x)}")
# f(x)  = [0. 0. 0. 1. 1.]

print(f"df(x) = {df(x)}")
# df(x) = [0. 0. 0. 0. 0.]

梯度处处为零的事实乍一看可能会令人困惑:毕竟,输出确实会响应输入而变化,那么梯度怎么会为零呢?但是,在这种情况下,零是正确的结果。

为什么会这样?请记住,微分测量的是给定 x 的无穷小变化时 f 的变化。对于 x=1.0f 返回 1.0。如果我们扰动 x 使其略大或略小,这不会改变输出,因此根据定义,grad(f)(1.0) 应为零。对于所有大于零的 f 值,此逻辑都适用:对输入进行无穷小的扰动不会改变输出,因此梯度为零。同样,对于所有小于零的 x 值,输出为零。扰动 x 不会改变此输出,因此梯度为零。这给我们留下了 x=0 的棘手情况。当然,如果您向上扰动 x,它会改变输出,但这有问题:x 的无穷小变化会导致函数值发生有限的变化,这意味着梯度未定义。幸运的是,我们还有另一种方法来衡量这种情况下的梯度:我们向下扰动函数,在这种情况下,输出不会改变,因此梯度为零。JAX 和其他自动微分系统倾向于以这种方式处理不连续性:如果正梯度和负梯度不一致,但一个是已定义的,另一个是未定义的,则我们使用已定义的那个。根据这种梯度的定义,从数学和数值上来说,该函数的梯度处处为零。

问题源于我们的函数在 x = 0 处存在不连续性。 这里的 f 本质上是一个 Heaviside 阶跃函数,我们可以使用 Sigmoid 函数 作为平滑的替代。当 x 远离零时,Sigmoid 函数近似等于 Heaviside 函数,但它用平滑、可微分的曲线代替了 x = 0 处的不连续性。 由于使用了 jax.nn.sigmoid(),我们得到了一个具有明确梯度的类似计算。

def g(x):
  return jax.nn.sigmoid(x)

dg = jax.vmap(jax.grad(g))

x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])

with np.printoptions(suppress=True, precision=2):
  print(f"g(x)  = {g(x)}")
  # g(x)  = [0.   0.27 0.5  0.73 1.  ]

  print(f"dg(x) = {dg(x)}")
  # dg(x) = [0.   0.2  0.25 0.2  0.  ]

jax.nn 子模块还具有其他常见基于排名的函数的平滑版本,例如,jax.nn.softmax() 可以替代 jax.numpy.argmax() 的使用,jax.nn.soft_sign() 可以替代 jax.numpy.sign() 的使用,jax.nn.softplus()jax.nn.squareplus() 可以替代 jax.nn.relu() 的使用,等等。

如何将 JAX Tracer 转换为 NumPy 数组?#

当在运行时检查转换后的 JAX 函数时,你会发现数组值被 Tracer 对象替换了。

@jax.jit
def f(x):
  print(type(x))
  return x

f(jnp.arange(5))

这会打印以下内容

<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>

一个常见的问题是如何将这种 Tracer 转换回普通的 NumPy 数组。简而言之,**不可能将 Tracer 转换为 NumPy 数组**,因为 Tracer 是具有给定形状和 dtype 的*每个可能*值的抽象表示,而 numpy 数组是该抽象类的具体成员。有关 Tracer 如何在 JAX 转换的上下文中工作的更多讨论,请参阅 JIT 机制

将 Tracers 转换回数组的问题通常出现在另一个目标的上下文中,该目标与在运行时访问计算中的中间值有关。 例如

  • 如果你希望在运行时打印跟踪的值以进行调试,你可以考虑使用 jax.debug.print()

  • 如果你希望在转换后的 JAX 函数中调用非 JAX 代码,你可以考虑使用 jax.pure_callback(),其示例可在 纯回调示例中找到。

  • 如果你希望在运行时输入或输出数组缓冲区(例如,从文件加载数据或将数组内容记录到磁盘),你可以考虑使用 jax.experimental.io_callback(),其示例可在 IO 回调示例中找到。

有关运行时回调及其使用示例的更多信息,请参阅 JAX 中的外部回调

为什么某些 CUDA 库无法加载/初始化?#

在解析动态库时,JAX 使用通常的 动态链接器搜索模式。JAX 将 RPATH 设置为指向 pip 安装的 NVIDIA CUDA 包的 JAX 相对位置,如果已安装,则优先使用它们。如果 ld.so 无法在其通常的搜索路径中找到你的 CUDA 运行时库,则你必须在 LD_LIBRARY_PATH 中显式包含这些库的路径。确保你的 CUDA 文件可被发现的最简单方法是直接安装 nvidia-*-cu12 pip 包,这些包包含在标准的 jax[cuda_12] 安装选项中。

有时,即使你已确保你的运行时库可被发现,仍然可能存在加载或初始化它们的问题。此类问题的常见原因是运行时 CUDA 库初始化内存不足。有时会出现这种情况,因为 JAX 会预先分配过大的当前可用设备内存块以加快执行速度,有时会导致运行时 CUDA 库初始化可用的内存不足。

当运行多个 JAX 实例、与执行自己预分配的 TensorFlow 并行运行 JAX,或者当 GPU 被其他进程大量使用时,这种情况尤其可能发生。如有疑问,请尝试通过减少默认值 .75XLA_PYTHON_CLIENT_MEM_FRACTION,或设置 XLA_PYTHON_CLIENT_PREALLOCATE=false 来减少预分配再次运行该程序。有关更多详细信息,请参阅 JAX GPU 内存分配页面。