常见问题 (FAQ)#

我们正在此处收集常见问题的答案。欢迎投稿!

jit 改变了我的函数的行为#

如果您有一个在使用 jax.jit() 后行为发生变化的 Python 函数,则您的函数可能使用了全局状态,或者具有副作用。在以下代码中,impure_func 使用全局 y 并且由于 print 而具有副作用

y = 0

# @jit   # Different behavior with jit
def impure_func(x):
  print("Inside:", y)
  return x + y

for y in range(3):
  print("Result:", impure_func(y))

没有 jit,输出为

Inside: 0
Result: 0
Inside: 1
Result: 2
Inside: 2
Result: 4

而使用 jit,输出为

Inside: 0
Result: 0
Result: 1
Result: 2

对于 jax.jit(),函数会使用 Python 解释器执行一次,此时会发生 Inside 打印,并且会观察到 y 的第一个值。然后,函数会被编译和缓存,并使用 x 的不同值多次执行,但 y 的第一个值保持不变。

其他阅读材料

jit 会改变输出的精确数值#

有时用户会对这样一个事实感到惊讶:用 jit() 包装函数可能会改变函数的输出。例如

>>> from jax import jit
>>> import jax.numpy as jnp
>>> def f(x):
...   return jnp.log(jnp.sqrt(x))
>>> x = jnp.pi
>>> print(f(x))
0.572365
>>> print(jit(f)(x))
0.5723649

输出的细微差异来自 XLA 编译器内部的优化:在编译期间,XLA 有时会重新排列或省略某些操作,以使整体计算更有效率。

在这种情况下,XLA 利用对数的性质将 log(sqrt(x)) 替换为 0.5 * log(x),这是一个数学上等价的表达式,其计算效率比原始表达式更高。输出的差异来自这样一个事实:浮点运算只是对实数数学的近似,因此计算同一表达式的不同方法可能会有细微的差异。

其他时候,XLA 的优化可能会导致更大的差异。考虑以下示例

>>> def f(x):
...   return jnp.log(jnp.exp(x))
>>> x = 100.0
>>> print(f(x))
inf
>>> print(jit(f)(x))
100.0

在非 JIT 编译的逐操作模式下,结果为 inf,因为 jnp.exp(x) 溢出并返回 inf。然而,在 JIT 下,XLA 识别出 logexp 的逆运算,并从编译后的函数中删除这些操作,只返回输入。在这种情况下,JIT 编译产生了对实数结果更准确的浮点近似。

不幸的是,XLA 的代数简化完整列表没有很好的文档记录,但如果您熟悉 C++ 并且好奇 XLA 编译器会进行哪种类型的优化,您可以在源代码中查看它们: algebraic_simplifier.cc

jit 装饰的函数编译速度非常慢#

如果您的 jit 装饰的函数第一次调用需要数十秒(甚至更长时间)才能运行,但在再次调用时执行速度很快,则表示 JAX 花费了很长时间来追踪或编译您的代码。

这通常表示调用您的函数会在 JAX 的内部表示中生成大量代码,通常是因为它大量使用了 Python 控制流,例如 for 循环。对于少量循环迭代,Python 可以胜任,但如果您需要大量循环迭代,则应重写代码以利用 JAX 的 结构化控制流原语(例如 lax.scan())或避免用 jit 包装循环(您仍然可以在循环内部使用 jit 装饰的函数)。

如果您不确定这是否是问题所在,您可以尝试在您的函数上运行 jax.make_jaxpr()。如果输出有数百或数千行,则可以预期编译速度会很慢。

有时不清楚如何重写代码以避免 Python 循环,因为您的代码使用了许多形状不同的数组。在这种情况下,建议的解决方案是使用诸如 jax.numpy.where() 之类的函数,在具有固定形状的填充数组上进行计算。

如果您的函数由于其他原因编译速度缓慢,请在 GitHub 上提交问题。

如何在方法中使用 jit#

jax.jit() 的大多数示例都涉及装饰独立的 Python 函数,但在类中装饰方法会带来一些复杂性。例如,考虑以下简单的类,我们在其中对方法使用了标准的 jit() 注解

>>> import jax.numpy as jnp
>>> from jax import jit

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit  # <---- How to do this correctly?
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y

但是,当您尝试调用此方法时,此方法会导致错误

>>> c = CustomClass(2, True)
>>> c.calc(3)  
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
  File "<stdin>", line 1, in <module
TypeError: Argument '<CustomClass object at 0x7f7dd4125890>' of type <class 'CustomClass'> is not a valid JAX type.

问题在于函数的第一个参数是 self,其类型为 CustomClass,而 JAX 不知道如何处理此类型。在这种情况下,我们可以使用三种基本策略,我们将在下面讨论它们。

策略 1:JIT 编译的辅助函数#

最直接的方法是创建一个位于类外部的辅助函数,该函数可以以正常方式进行 JIT 装饰。例如

>>> from functools import partial

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   def calc(self, y):
...     return _calc(self.mul, self.x, y)

>>> @partial(jit, static_argnums=0)
... def _calc(mul, x, y):
...   if mul:
...     return x * y
...   return y

结果将按预期工作

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

这种方法的好处在于它简单、明确,并且避免了需要教 JAX 如何处理类型为 CustomClass 的对象。但是,您可能希望将所有方法逻辑保存在同一位置。

策略 2:将 self 标记为静态#

另一种常见模式是使用 static_argnumsself 参数标记为静态。但这必须谨慎操作以避免意外结果。您可能会想简单地这样做

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   # WARNING: this example is broken, as we'll see below. Don't copy & paste!
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y

如果您调用该方法,它将不再引发错误

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

但是,有一个陷阱:如果您在第一次方法调用后修改了对象,则后续方法调用可能会返回不正确的结果

>>> c.mul = False
>>> print(c.calc(3))  # Should print 3
6

为什么会这样?当您将对象标记为静态时,它实际上将被用作 JIT 内部编译缓存中的字典键,这意味着它的哈希值(即 hash(obj))相等性(即 obj1 == obj2)和对象标识(即 obj1 is obj2)将被假定具有一致的行为。自定义对象的默认 __hash__ 是其对象 ID,因此 JAX 无法知道已修改的对象应该触发重新编译。

您可以通过为您的对象定义合适的 __hash____eq__ 方法来部分解决此问题;例如

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def __hash__(self):
...     return hash((self.x, self.mul))
...
...   def __eq__(self, other):
...     return (isinstance(other, CustomClass) and
...             (self.x, self.mul) == (other.x, other.mul))

(有关覆盖 __hash__ 时要求的更多讨论,请参阅 object.__hash__() 文档)。

只要您从不修改对象,这应该与 JIT 和其他转换一起正常工作。用作哈希键的可变对象的修改会导致一些微妙的问题,这就是为什么例如可变 Python 容器(例如 dictlist)不定义 __hash__,而它们的不可变对应物(例如 tuple)定义了 __hash__

如果您的类依赖于就地修改(例如在其方法中设置 self.attr = ...),那么您的对象实际上不是“静态的”,将其标记为静态可能会导致问题。幸运的是,在这种情况下还有另一个选项。

策略 3:将 CustomClass 设为 PyTree#

正确 JIT 编译类方法最灵活的方法是将类型注册为自定义 PyTree 对象;请参阅 扩展 pytrees。这允许您精确指定类的哪些组件应被视为静态,哪些应被视为动态。以下是如何实现的

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def _tree_flatten(self):
...     children = (self.x,)  # arrays / dynamic values
...     aux_data = {'mul': self.mul}  # static values
...     return (children, aux_data)
...
...   @classmethod
...   def _tree_unflatten(cls, aux_data, children):
...     return cls(*children, **aux_data)

>>> from jax import tree_util
>>> tree_util.register_pytree_node(CustomClass,
...                                CustomClass._tree_flatten,
...                                CustomClass._tree_unflatten)

这当然更复杂,但它解决了上述更简单方法相关的所有问题

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

>>> c.mul = False  # mutation is detected
>>> print(c.calc(3))
3

>>> c = CustomClass(jnp.array(2), True)  # non-hashable x is supported
>>> print(c.calc(3))
6

只要您的 tree_flattentree_unflatten 函数正确处理类中的所有相关属性,您应该能够直接将此类型的对象用作 JIT 编译函数的参数,而无需任何特殊注释。

控制设备上的数据和计算放置#

让我们首先了解 JAX 中数据和计算放置的原理。

在 JAX 中,计算遵循数据放置。JAX 数组有两个放置属性:1) 数据驻留的设备;2) 数据是否已**提交**到设备(数据有时被称为对设备具有粘性)。

默认情况下,JAX 数组在默认设备(jax.devices()[0])上未提交放置,默认设备默认为第一个 GPU 或 TPU。如果不存在 GPU 或 TPU,则 jax.devices()[0] 为 CPU。默认设备可以使用 jax.default_device() 上下文管理器临时覆盖,或通过设置环境变量 JAX_PLATFORMS 或 absl 标志 --jax_platforms 为“cpu”、“gpu”或“tpu”来设置整个进程(JAX_PLATFORMS 也可以是平台列表,它确定按优先级顺序可用的平台)。

>>> from jax import numpy as jnp
>>> print(jnp.ones(3).devices())  
{CudaDevice(id=0)}

涉及未提交数据的计算在默认设备上执行,结果在默认设备上未提交。

数据也可以使用 jax.device_put()device 参数显式放置在设备上,在这种情况下,数据会**提交**到设备

>>> import jax
>>> from jax import device_put
>>> arr = device_put(1, jax.devices()[2])  
>>> print(arr.devices())  
{CudaDevice(id=2)}

涉及某些已提交输入的计算将在已提交的设备上执行,结果也将提交到同一设备上。对提交到多个设备的参数调用操作将引发错误。

您也可以使用 jax.device_put() 而不带 device 参数。如果数据已在设备上(已提交或未提交),则保持原样。如果数据不在任何设备上,即它是常规的 Python 或 NumPy 值,则将其未提交放置在默认设备上。

JIT 函数的行为与任何其他基本操作相同——它们将遵循数据,如果在提交到多个设备上的数据上调用,则会显示错误。

(在 2021 年 3 月的 PR #6002 之前,数组常量的创建存在一些延迟,因此 jax.device_put(jnp.zeros(...), jax.devices()[1]) 或类似操作实际上会在 jax.devices()[1] 上创建零数组,而不是在默认设备上创建数组然后移动它。但此优化已删除,以简化实现。)

(截至 2020 年 4 月,jax.jit() 具有影响设备放置的 device 参数。该参数是实验性的,可能会被删除或更改,不建议使用。)

有关详细示例,我们建议阅读 multi_device_test.py 中的 test_computation_follows_data

JAX 代码基准测试#

您刚刚将一个棘手的函数从 NumPy/SciPy 移植到 JAX。这是否真的加快了速度?

在使用 JAX 测量代码速度时,请记住与 NumPy 的这些重要区别

  1. **JAX 代码是即时 (JIT) 编译的。**大多数用 JAX 编写的代码都可以以支持 JIT 编译的方式编写,这可以使其运行得*快得多*(请参阅 是否使用 JIT)。要从 JAX 获得最大性能,您应该在最外层的函数调用上应用 jax.jit()

    请记住,第一次运行 JAX 代码时,速度会比较慢,因为此时正在进行编译。即使您在自己的代码中不使用 jit,情况也是如此,因为 JAX 的内置函数也是 JIT 编译的。

  2. **JAX 具有异步调度。**这意味着您需要调用 .block_until_ready() 以确保计算已实际发生(请参阅 异步调度)。

  3. **JAX 默认只使用 32 位数据类型。**您可能希望在 NumPy 中显式使用 32 位数据类型,或在 JAX 中启用 64 位数据类型(请参阅 双精度 (64 位)),以便进行公平比较。

  4. **在 CPU 和加速器之间传输数据需要时间。**如果您只想测量评估函数需要多长时间,则可能需要先将数据传输到要运行它的设备上(请参阅 控制设备上的数据和计算放置)。

以下是如何将所有这些技巧组合到一个用于比较 JAX 与 NumPy 的微基准测试中的示例,利用 IPython 方便的 %time 和 %timeit 魔法命令

import numpy as np
import jax.numpy as jnp
import jax

def f(x):  # function we're benchmarking (works in both NumPy & JAX)
  return x.T @ (x - x.mean(axis=0))

x_np = np.ones((1000, 1000), dtype=np.float32)  # same as JAX default dtype
%timeit f(x_np)  # measure NumPy runtime

%time x_jax = jax.device_put(x_np)  # measure JAX device transfer time
f_jit = jax.jit(f)
%time f_jit(x_jax).block_until_ready()  # measure JAX compilation time
%timeit f_jit(x_jax).block_until_ready()  # measure JAX runtime

Colab 中使用 GPU 运行时,我们将看到

  • NumPy 在 CPU 上每次评估需要 16.2 毫秒

  • JAX 将 NumPy 数组复制到 GPU 上需要 1.26 毫秒

  • JAX 编译函数需要 193 毫秒

  • JAX 在 GPU 上每次评估需要 485 微秒

在这种情况下,我们看到,一旦数据传输完成并且函数编译完成,JAX 在 GPU 上对重复评估的速度大约快 30 倍。

这是一个公平的比较吗?也许吧。最终重要的性能是运行完整的应用程序,这不可避免地包括一定数量的数据传输和编译。此外,我们小心地选择了足够大的数组(1000x1000)和足够密集的计算(@ 运算符正在执行矩阵-矩阵乘法),以摊销 JAX/加速器与 NumPy/CPU 的增加的开销。例如,如果我们将此示例切换为使用 10x10 输入,则 JAX/GPU 的运行速度比 NumPy/CPU 慢 10 倍(100 微秒对 10 微秒)。

JAX 是否比 NumPy 快?#

用户经常使用此类基准测试试图回答的一个问题是 JAX 是否比 NumPy 快;由于这两个软件包之间的差异,没有简单的答案。

广义上讲

  • NumPy 操作是急切、同步执行的,并且仅在 CPU 上执行。

  • JAX 操作可能在编译前或编译后执行(如果在 jit() 内部);它们是异步调度的(请参阅 异步调度);并且它们可以在 CPU、GPU 或 TPU 上执行,每种设备具有截然不同且不断变化的性能特征。

这些架构差异使得在 NumPy 和 JAX 之间进行有意义的直接基准比较变得困难。

此外,这些差异导致了这两个软件包之间不同的工程重点:例如,NumPy 已投入大量精力来减少单个数组操作的每次调用调度开销,因为在 NumPy 的计算模型中,无法避免此开销。另一方面,JAX 有几种方法可以避免调度开销(例如 JIT 编译、异步调度、批处理转换等),因此降低每次调用开销的优先级较低。

牢记所有这些,总而言之:如果您正在对 CPU 上的单个数组操作进行微基准测试,则通常可以预期 NumPy 会由于其较低的每次操作调度开销而优于 JAX。如果您在 GPU 或 TPU 上运行代码,或者正在对 CPU 上更复杂的 JIT 编译的操作序列进行基准测试,则通常可以预期 JAX 会优于 NumPy。

不同类型的 JAX 值#

在转换函数的过程中,JAX 会用特殊的跟踪器值替换某些函数参数。

如果您使用 print 语句,则可以看到这一点

def func(x):
  print(x)
  return jnp.cos(x)

res = jax.jit(func)(0.)

以上代码确实返回了正确的值 1.,但它也为 x 的值打印了 Traced<ShapedArray(float32[])>。通常,JAX 以透明的方式在内部处理这些跟踪器值,例如,在用于实现 jax.numpy 函数的数值 JAX 基本操作中。这就是为什么 jnp.cos 在上面的示例中有效的原因。

更准确地说,跟踪器值是为 JAX 转换函数的参数引入的,除了由特殊参数(例如 jax.jit()static_argnumsjax.pmap()static_broadcasted_argnums)标识的参数。通常,涉及至少一个跟踪器值的计算将产生一个跟踪器值。除了跟踪器值之外,还有常规 Python 值:在 JAX 转换之外计算的值,或来自某些 JAX 转换的上述静态参数,或仅从其他常规 Python 值计算的值。这些是在没有 JAX 转换的情况下随处使用的值。

跟踪器值携带一个抽象值,例如,带有有关数组形状和数据类型的信息的 ShapedArray。在这里,我们将此类跟踪器称为抽象跟踪器。某些跟踪器,例如为自动微分转换的参数引入的那些跟踪器,携带 ConcreteArray 抽象值,这些值实际上包含常规数组数据,并用于例如解析条件。在这里,我们将此类跟踪器称为具体跟踪器。从这些具体跟踪器(可能与常规值结合)计算出的跟踪器值会导致具体跟踪器。具体值是常规值或具体跟踪器。

大多数情况下,从跟踪器值计算出的值本身就是跟踪器值。只有极少数例外情况,当计算可以使用跟踪器携带的抽象值完全完成时,在这种情况下,结果可以是常规值。例如,使用带有 ShapedArray 抽象值的跟踪器获取形状。另一个示例是将具体跟踪器值显式转换为常规类型,例如 int(x)x.astype(float)。另一种情况是 bool(x),当具体性允许时,它会产生一个 Python 布尔值。这种情况尤其突出,因为它在控制流中经常出现。

以下是转换如何引入抽象或具体跟踪器

  • jax.jit():为所有位置参数引入抽象跟踪器,除了由 static_argnums 表示的参数,这些参数保持常规值。

  • jax.pmap():为所有位置参数引入抽象跟踪器,除了由 static_broadcasted_argnums 表示的参数。

  • jax.vmap()jax.make_jaxpr()xla_computation():为所有位置参数引入抽象跟踪器

  • jax.jvp()jax.grad() 为所有位置参数引入具体跟踪器。一个例外是当这些转换在外部转换中并且实际参数本身是抽象跟踪器时;在这种情况下,自动微分转换引入的跟踪器也是抽象跟踪器。

  • 所有高阶控制流原语(lax.cond()lax.while_loop()lax.fori_loop()lax.scan())在处理函数时都会引入**抽象追踪器**,无论是否正在进行 JAX 变换。

所有这些都与你的代码只能操作普通 Python 值相关,例如基于数据的条件控制流代码。

def divide(x, y):
  return x / y if y >= 1. else 0.

如果我们想应用jax.jit(),我们必须确保指定static_argnums=1以确保y保持普通值。这是因为布尔表达式y >= 1.需要具体的值(普通值或追踪器)。如果我们显式地写bool(y >= 1.)int(y)float(y),也会发生同样的情况。

有趣的是,jax.grad(divide)(3., 2.)可以工作,因为jax.grad()使用具体追踪器,并使用y的具体值来解决条件。

缓冲区捐赠#

当 JAX 执行计算时,它会为所有输入和输出使用设备上的缓冲区。如果你知道某个输入在计算后不再需要,并且如果它与某个输出的形状和元素类型匹配,你可以指定希望将相应的输入缓冲区捐赠来保存输出。这将减少执行所需的内存,减少捐赠缓冲区的大小。

如果你有类似以下模式的内容,你可以使用缓冲区捐赠。

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state)

你可以将其视为一种对你的不可变 JAX 数组进行内存高效的功能更新的方法。在计算的边界内,XLA 可以为你进行此优化,但在 jit/pmap 边界,你需要保证 XLA 在调用捐赠函数后不会使用捐赠的输入缓冲区。

你可以通过使用函数jax.jit()jax.pjit()jax.pmap()donate_argnums参数来实现这一点。此参数是位置参数列表(从 0 开始)的索引序列。

def add(x, y):
  return x + y

x = jax.device_put(np.ones((2, 3)))
y = jax.device_put(np.ones((2, 3)))
# Execute `add` with donation of the buffer for `y`. The result has
# the same shape and type as `y`, so it will share its buffer.
z = jax.jit(add, donate_argnums=(1,))(x, y)

请注意,当使用关键字参数调用你的函数时,目前这不起作用!以下代码不会捐赠任何缓冲区。

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params=params, state=state)

如果缓冲区被捐赠的参数是 pytree,则其所有组件的缓冲区都会被捐赠。

def add_ones(xs: List[Array]):
  return [x + 1 for x in xs]

xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))]
# Execute `add_ones` with donation of all the buffers for `xs`.
# The outputs have the same shape and type as the elements of `xs`,
# so they will share those buffers.
z = jax.jit(add_ones, donate_argnums=0)(xs)

不允许捐赠随后在计算中使用的缓冲区,并且 JAX 会报错,因为y的缓冲区在捐赠后已失效。

# Donate the buffer for `y`
z = jax.jit(add, donate_argnums=(1,))(x, y)
w = y + 1  # Reuses `y` whose buffer was donated above
# >> RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer

如果未使用的缓冲区被捐赠,例如,捐赠的缓冲区数量超过了输出可以使用的数量,则会发出警告。

# Execute `add` with donation of the buffers for both `x` and `y`.
# One of those buffers will be used for the result, but the other will
# not be used.
z = jax.jit(add, donate_argnums=(0, 1))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[2,3]{1,0}

如果没有任何输出的形状与捐赠匹配,则捐赠也可能未被使用。

y = jax.device_put(np.ones((1, 3)))  # `y` has different shape than the output
# Execute `add` with donation of the buffer for `y`.
z = jax.jit(add, donate_argnums=(1,))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[1,3]{1,0}

使用where时,梯度包含NaN#

如果你使用where定义一个函数来避免未定义的值,如果你不小心,你可能会得到反向微分的NaN

def my_log(x):
  return jnp.where(x > 0., jnp.log(x), 0.)

my_log(0.) ==> 0.  # Ok
jax.grad(my_log)(0.)  ==> NaN

简而言之,在grad计算期间,与未定义的jnp.log(x)对应的伴随量是NaN,它会被累加到jnp.where的伴随量。编写此类函数的正确方法是确保在部分定义的函数内部有一个jnp.where,以确保伴随量始终是有限的。

def safe_for_grad_log(x):
  return jnp.log(jnp.where(x > 0., x, 1.))

safe_for_grad_log(0.) ==> 0.  # Ok
jax.grad(safe_for_grad_log)(0.)  ==> 0.  # Ok

除了原始的jnp.where之外,可能还需要内部的jnp.where,例如:

def my_log_or_y(x, y):
  """Return log(x) if x > 0 or y"""
  return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)

其他阅读材料

为什么基于排序顺序的函数的梯度为零?#

如果你定义了一个使用依赖于输入相对顺序的操作(例如maxgreaterargsort等)来处理输入的函数,那么你可能会惊讶地发现梯度处处为零。这是一个例子,我们定义f(x)为一个阶跃函数,当x为负时返回0,当x为正时返回1

import jax
import numpy as np
import jax.numpy as jnp

def f(x):
  return (x > 0).astype(float)

df = jax.vmap(jax.grad(f))

x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0])

print(f"f(x)  = {f(x)}")
# f(x)  = [0. 0. 0. 1. 1.]

print(f"df(x) = {df(x)}")
# df(x) = [0. 0. 0. 0. 0.]

梯度处处为零的事实乍一看可能令人困惑:毕竟,输出确实会根据输入而变化,那么梯度怎么会为零呢?然而,在这种情况下,零恰好是正确的结果。

为什么会这样?请记住,微分测量的是给定x的无穷小变化时f的变化。对于x=1.0f返回1.0。如果我们扰动x使其略微变大或变小,这不会改变输出,因此根据定义,grad(f)(1.0)应该为零。同样的逻辑适用于大于零的所有f值:无穷小地扰动输入不会改变输出,因此梯度为零。类似地,对于所有小于零的x值,输出为零。扰动x不会改变此输出,因此梯度为零。这让我们陷入了棘手的x=0情况。当然,如果你向上扰动x,它会改变输出,但这存在问题:x的无穷小变化会导致函数值的有限变化,这意味着梯度未定义。幸运的是,我们还有另一种方法来测量这种情况下的梯度:我们向下扰动函数,在这种情况下,输出不会改变,因此梯度为零。JAX 和其他自动微分系统倾向于以这种方式处理不连续性:如果正梯度和负梯度不一致,但其中一个已定义而另一个未定义,则我们使用已定义的那个。根据此梯度的定义,从数学和数值上讲,此函数的梯度处处为零。

问题源于我们的函数在x = 0处存在不连续性。我们这里的f本质上是一个Heaviside 阶跃函数,我们可以使用Sigmoid 函数作为平滑的替代。当x远离零时,sigmoid 近似等于 heaviside 函数,但用平滑的可微曲线替换了x = 0处的不连续性。由于使用了jax.nn.sigmoid(),我们得到了一个具有良好定义梯度的类似计算。

def g(x):
  return jax.nn.sigmoid(x)

dg = jax.vmap(jax.grad(g))

x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])

with np.printoptions(suppress=True, precision=2):
  print(f"g(x)  = {g(x)}")
  # g(x)  = [0.   0.27 0.5  0.73 1.  ]

  print(f"dg(x) = {dg(x)}")
  # dg(x) = [0.   0.2  0.25 0.2  0.  ]

jax.nn 子模块还具有其他常见基于秩的函数的平滑版本,例如jax.nn.softmax() 可以替换jax.numpy.argmax()的使用,jax.nn.soft_sign() 可以替换jax.numpy.sign()的使用,jax.nn.softplus()jax.nn.squareplus() 可以替换jax.nn.relu()的使用,等等。

如何将 JAX 追踪器转换为 NumPy 数组?#

在运行时检查转换后的 JAX 函数时,你会发现数组值被Tracer对象替换。

@jax.jit
def f(x):
  print(type(x))
  return x

f(jnp.arange(5))

这将打印以下内容。

<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>

一个常见的问题是如何将这样的追踪器转换回普通的 NumPy 数组。简而言之,**不可能将追踪器转换为 NumPy 数组**,因为追踪器是具有给定形状和数据类型的*所有可能*值的抽象表示,而 NumPy 数组是该抽象类的具体成员。有关追踪器如何在 JAX 变换的上下文中工作的更多讨论,请参阅JIT 机制

将追踪器转换回数组的问题通常出现在另一个目标的上下文中,该目标与在运行时访问计算中的中间值有关。例如:

  • 如果您希望在运行时打印跟踪值以进行调试,可以考虑使用 jax.debug.print()

  • 如果您希望在转换后的 JAX 函数中调用非 JAX 代码,可以考虑使用 jax.pure_callback(),其示例可在 纯回调示例 中找到。

  • 如果您希望在运行时输入或输出数组缓冲区(例如,从文件加载数据或将数组内容记录到磁盘),可以考虑使用 jax.experimental.io_callback(),其示例可在 IO 回调示例 中找到。

有关运行时回调及其用法的更多信息,请参阅 JAX 中的外部回调

为什么某些 CUDA 库无法加载/初始化?#

在解析动态库时,JAX 使用通常的 动态链接器搜索模式。JAX 设置 RPATH 以指向通过 pip 安装的 NVIDIA CUDA 包的 JAX 相对位置,如果已安装则优先使用它们。如果 ld.so 无法在其通常的搜索路径中找到您的 CUDA 运行时库,则必须在 LD_LIBRARY_PATH 中显式包含这些库的路径。确保您的 CUDA 文件可发现的最简单方法是简单地安装 nvidia-*-cu12 pip 包,这些包包含在标准的 jax[cuda_12] 安装选项中。

偶尔,即使您已确保您的运行时库可发现,加载或初始化它们时仍可能存在一些问题。此类问题的一个常见原因是在运行时 CUDA 库初始化时内存不足。有时会发生这种情况,因为 JAX 会预先分配过大一部分当前可用的设备内存以加快执行速度,偶尔会导致无法为运行时 CUDA 库初始化留出足够的内存。

当运行多个 JAX 实例、与执行自身预分配的 TensorFlow 同时运行 JAX 或在 GPU 被其他进程大量使用的系统上运行 JAX 时,这种情况尤其容易发生。如有疑问,请尝试使用减少的预分配再次运行程序,方法是将 XLA_PYTHON_CLIENT_MEM_FRACTION 从默认值 .75 降低,或设置 XLA_PYTHON_CLIENT_PREALLOCATE=false。有关更多详细信息,请参阅有关 JAX GPU 内存分配 的页面。