常见问题解答 (FAQ)#
我们正在这里收集常见问题的答案。欢迎贡献!
jit
更改了我的函数的行为#
如果你的 Python 函数在使用 jax.jit()
后更改了行为,可能是你的函数使用了全局状态或具有副作用。在以下代码中,impure_func
使用全局变量 y
,并且由于 print
而具有副作用
y = 0
# @jit # Different behavior with jit
def impure_func(x):
print("Inside:", y)
return x + y
for y in range(3):
print("Result:", impure_func(y))
没有 jit
的输出是
Inside: 0
Result: 0
Inside: 1
Result: 2
Inside: 2
Result: 4
而使用 jit
的输出是
Inside: 0
Result: 0
Result: 1
Result: 2
对于 jax.jit()
,该函数使用 Python 解释器执行一次,此时会发生 Inside
打印,并观察到 y
的第一个值。然后,该函数被编译并缓存,并使用不同的 x
值执行多次,但使用相同的 y
的第一个值。
其他阅读材料
jit
更改了输出的精确数值#
有时,用户会惊讶于用 jit()
包装函数可以更改函数的输出。例如
>>> from jax import jit
>>> import jax.numpy as jnp
>>> def f(x):
... return jnp.log(jnp.sqrt(x))
>>> x = jnp.pi
>>> print(f(x))
0.572365
>>> print(jit(f)(x))
0.5723649
输出的这种细微差异来自 XLA 编译器中的优化:在编译期间,XLA 有时会重新排列或省略某些操作,以使整体计算更有效率。
在这种情况下,XLA 利用对数的属性,用 0.5 * log(x)
替换 log(sqrt(x))
,这是一个数学上相同的表达式,可以比原始表达式更有效地计算。输出的差异来自于浮点算术只是对真实数学的近似,因此计算相同表达式的不同方式可能具有细微不同的结果。
有时,XLA 的优化可能会导致更大的差异。考虑以下示例
>>> def f(x):
... return jnp.log(jnp.exp(x))
>>> x = 100.0
>>> print(f(x))
inf
>>> print(jit(f)(x))
100.0
在非 JIT 编译的逐个操作模式下,结果为 inf
,因为 jnp.exp(x)
溢出并返回 inf
。但是,在 JIT 下,XLA 识别到 log
是 exp
的逆运算,并从编译后的函数中删除这些操作,简单地返回输入。在这种情况下,JIT 编译生成了更精确的真实结果的浮点近似值。
不幸的是,XLA 的代数简化的完整列表没有详细记录,但是如果您熟悉 C++ 并对 XLA 编译器进行的优化类型感到好奇,可以在源代码中看到它们:algebraic_simplifier.cc。
jit
装饰的函数编译速度非常慢#
如果你的 jit
装饰的函数在第一次调用时需要数十秒(或更长时间)才能运行,但在再次调用时执行速度很快,则表示 JAX 需要很长时间来跟踪或编译你的代码。
这通常表明调用你的函数会在 JAX 的内部表示中生成大量代码,通常是因为它大量使用 Python 控制流,例如 for
循环。对于少量循环迭代,Python 还可以,但是如果你需要很多循环迭代,则应重写代码以使用 JAX 的结构化控制流原语(例如 lax.scan()
)或避免使用 jit
包装循环(你仍然可以在循环内部使用 jit
装饰的函数)。
如果你不确定这是否是问题,可以尝试在你的函数上运行 jax.make_jaxpr()
。如果输出的长度为数百或数千行,则可以预期编译速度会很慢。
有时,不清楚如何重写代码以避免 Python 循环,因为你的代码使用了许多具有不同形状的数组。在这种情况下,建议的解决方案是使用诸如 jax.numpy.where()
之类的函数在具有固定形状的填充数组上执行计算。
如果你的函数由于其他原因编译速度很慢,请在 GitHub 上打开一个问题。
如何在方法中使用 jit
?#
jax.jit()
的大多数示例都涉及装饰独立的 Python 函数,但是装饰类中的方法会引入一些复杂性。例如,考虑以下简单类,我们在其中使用了标准 jit()
注释在方法上
>>> import jax.numpy as jnp
>>> from jax import jit
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... @jit # <---- How to do this correctly?
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
但是,当你尝试调用此方法时,此方法将导致错误
>>> c = CustomClass(2, True)
>>> c.calc(3)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File "<stdin>", line 1, in <module
TypeError: Argument '<CustomClass object at 0x7f7dd4125890>' of type <class 'CustomClass'> is not a valid JAX type.
问题在于该函数的第一个参数是 self
,其类型为 CustomClass
,而 JAX 不知道如何处理此类型。在这种情况下,我们可以使用三种基本策略,我们将在下面讨论它们。
策略 1:JIT 编译的辅助函数#
最直接的方法是创建一个类外部的辅助函数,可以以正常方式对其进行 JIT 装饰。例如
>>> from functools import partial
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... def calc(self, y):
... return _calc(self.mul, self.x, y)
>>> @partial(jit, static_argnums=0)
... def _calc(mul, x, y):
... if mul:
... return x * y
... return y
结果将按预期工作
>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6
这种方法的优点是它简单、明确,并且避免了教 JAX 如何处理 CustomClass
类型对象的需求。但是,你可能希望将所有方法逻辑保留在同一位置。
策略 2:将 self
标记为静态#
另一种常见的模式是使用 static_argnums
将 self
参数标记为静态。但是,必须小心执行此操作,以避免出现意外结果。你可能很想简单地这样做
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... # WARNING: this example is broken, as we'll see below. Don't copy & paste!
... @partial(jit, static_argnums=0)
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
如果调用该方法,则它将不再引发错误
>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6
但是,这里有一个陷阱:如果在第一次方法调用后更改对象,则后续的方法调用可能会返回不正确的结果
>>> c.mul = False
>>> print(c.calc(3)) # Should print 3
6
这是为什么呢?当你将对象标记为静态时,它实际上将用作 JIT 内部编译缓存中的字典键,这意味着它的哈希值(即 hash(obj)
)相等性(即 obj1 == obj2
)和对象标识(即 obj1 is obj2
)将被假定为具有一致的行为。自定义对象的默认 __hash__
是其对象 ID,因此 JAX 无法知道已更改的对象应触发重新编译。
你可以通过为对象定义适当的 __hash__
和 __eq__
方法来部分解决此问题;例如
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... @partial(jit, static_argnums=0)
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
...
... def __hash__(self):
... return hash((self.x, self.mul))
...
... def __eq__(self, other):
... return (isinstance(other, CustomClass) and
... (self.x, self.mul) == (other.x, other.mul))
(有关覆盖 __hash__
时的要求的更多讨论,请参见 object.__hash__()
文档)。
只要你从不更改对象,这应该可以与 JIT 和其他转换一起正常工作。用作哈希键的对象的更改会导致一些细微的问题,这就是为什么例如可变的 Python 容器(例如,dict
,list
)不定义 __hash__
,而其不可变的对应项(例如,tuple
)定义 __hash__
。
如果你的类依赖于就地更改(例如在其方法中设置 self.attr = ...
),则你的对象并非真正“静态”,将其标记为静态可能会导致问题。幸运的是,在这种情况下,还有另一种选择。
策略 3:使 CustomClass
成为 PyTree#
正确 JIT 编译类方法的最灵活的方法是将类型注册为自定义 PyTree 对象;请参阅 扩展 pytrees。这使你可以精确地指定应将类的哪些组件视为静态,哪些组件视为动态。以下是它的外观
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... @jit
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
...
... def _tree_flatten(self):
... children = (self.x,) # arrays / dynamic values
... aux_data = {'mul': self.mul} # static values
... return (children, aux_data)
...
... @classmethod
... def _tree_unflatten(cls, aux_data, children):
... return cls(*children, **aux_data)
>>> from jax import tree_util
>>> tree_util.register_pytree_node(CustomClass,
... CustomClass._tree_flatten,
... CustomClass._tree_unflatten)
这当然更复杂,但是它解决了与上面使用的较简单方法相关的所有问题
>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6
>>> c.mul = False # mutation is detected
>>> print(c.calc(3))
3
>>> c = CustomClass(jnp.array(2), True) # non-hashable x is supported
>>> print(c.calc(3))
6
只要你的 tree_flatten
和 tree_unflatten
函数正确处理了类中的所有相关属性,你就应该能够直接将此类型的对象用作 JIT 编译函数的参数,而无需任何特殊注释。
控制设备上的数据和计算放置#
让我们首先看看 JAX 中数据和计算放置的原则。
在 JAX 中,计算遵循数据放置。JAX 数组具有两个放置属性:1) 数据所在的设备;2) 是否提交到设备(有时将数据称为粘性到设备)。
默认情况下,JAX 数组以未提交状态放置在默认设备上 (jax.devices()[0]
),默认情况下这是第一个 GPU 或 TPU。如果不存在 GPU 或 TPU,则 jax.devices()[0]
是 CPU。可以使用 jax.default_device()
上下文管理器临时覆盖默认设备,或者通过设置环境变量 JAX_PLATFORMS
或 absl 标志 --jax_platforms
为“cpu”、“gpu”或“tpu”来为整个进程设置默认设备(JAX_PLATFORMS
也可以是平台列表,它决定了按优先级顺序可用的平台)。
>>> from jax import numpy as jnp
>>> print(jnp.ones(3).devices())
{CudaDevice(id=0)}
涉及未提交数据的计算在默认设备上执行,结果在默认设备上未提交。
也可以使用带有 device
参数的 jax.device_put()
将数据显式放置在设备上,在这种情况下,数据会提交到设备
>>> import jax
>>> from jax import device_put
>>> arr = device_put(1, jax.devices()[2])
>>> print(arr.devices())
{CudaDevice(id=2)}
涉及某些已提交输入的计算将在已提交的设备上进行,结果将提交到同一设备。在提交到多个设备的参数上调用操作将引发错误。
您也可以使用不带 device
参数的 jax.device_put()
。如果数据已经在设备上(已提交或未提交),则保持原样。如果数据不在任何设备上,即它是常规 Python 或 NumPy 值,则会将其以未提交状态放置在默认设备上。
Jitted 函数的行为类似于任何其他原始操作,它们将遵循数据,如果在提交到多个设备的数据上调用,将显示错误。
(在 2021 年 3 月的 PR #6002 之前,数组常数的创建存在一些延迟,因此 jax.device_put(jnp.zeros(...), jax.devices()[1])
或类似的代码实际上会在 jax.devices()[1]
上创建零数组,而不是在默认设备上创建数组然后移动它。但是此优化已被删除,以简化实现。)
(截至 2020 年 4 月,jax.jit()
具有影响设备放置的 device 参数。该参数是实验性的,可能会被删除或更改,不建议使用。)
对于一个完整的示例,我们建议阅读 multi_device_test.py 中的 test_computation_follows_data
。
基准测试 JAX 代码#
你刚刚将一个棘手的函数从 NumPy/SciPy 移植到 JAX。这真的加快了速度吗?
在测量使用 JAX 代码的速度时,请记住与 NumPy 的这些重要差异
JAX 代码是即时 (JIT) 编译的。 大多数用 JAX 编写的代码都可以通过支持 JIT 编译的方式编写,这可以使其运行快得多(请参阅 To JIT or not to JIT)。要从 JAX 获得最大性能,您应该在最外层的函数调用上应用
jax.jit()
。请记住,第一次运行 JAX 代码时会比较慢,因为它正在被编译。即使您在自己的代码中不使用
jit
,也是如此,因为 JAX 的内置函数也是 JIT 编译的。JAX 具有异步调度。 这意味着您需要调用
.block_until_ready()
以确保计算实际上已经发生(请参阅 异步调度)。JAX 默认仅使用 32 位 dtypes。 为了进行公平比较,您可能需要在 NumPy 中显式使用 32 位 dtypes,或在 JAX 中启用 64 位 dtypes(请参阅 Double (64 bit) precision)。
在 CPU 和加速器之间传输数据需要时间。 如果您只想测量评估函数所需的时间,您可能需要先将数据传输到您要运行它的设备上(请参阅 控制设备上的数据和计算放置)。
这是一个如何将所有这些技巧组合成一个微基准测试的示例,用于比较 JAX 和 NumPy,使用了 IPython 方便的 %time 和 %timeit magics
import numpy as np
import jax.numpy as jnp
import jax
def f(x): # function we're benchmarking (works in both NumPy & JAX)
return x.T @ (x - x.mean(axis=0))
x_np = np.ones((1000, 1000), dtype=np.float32) # same as JAX default dtype
%timeit f(x_np) # measure NumPy runtime
%time x_jax = jax.device_put(x_np) # measure JAX device transfer time
f_jit = jax.jit(f)
%time f_jit(x_jax).block_until_ready() # measure JAX compilation time
%timeit f_jit(x_jax).block_until_ready() # measure JAX runtime
在 Colab 中使用 GPU 运行时,我们看到
NumPy 在 CPU 上每次评估需要 16.2 毫秒
JAX 需要 1.26 毫秒将 NumPy 数组复制到 GPU 上
JAX 需要 193 毫秒来编译函数
JAX 在 GPU 上每次评估需要 485 微秒
在这种情况下,我们看到一旦传输数据并编译函数,在 GPU 上使用 JAX 进行重复评估的速度大约快 30 倍。
这是一个公平的比较吗?也许。最终重要的性能是用于运行完整应用程序的性能,这不可避免地包括一定数量的数据传输和编译。此外,我们很小心地选择了足够大的数组(1000x1000)和足够密集的计算(@
运算符执行矩阵-矩阵乘法),以分摊 JAX/加速器与 NumPy/CPU 相比增加的开销。例如,如果我们将此示例切换为使用 10x10 输入,则 JAX/GPU 的运行速度比 NumPy/CPU 慢 10 倍(100 微秒 vs 10 微秒)。
JAX 比 NumPy 快吗?#
用户经常尝试通过此类基准测试回答的一个问题是 JAX 是否比 NumPy 快;由于这两个包的差异,没有简单的答案。
总的来说
NumPy 操作是急切地、同步地执行的,并且仅在 CPU 上执行。
JAX 操作可以在急切执行或编译后执行(如果在
jit()
中);它们是异步调度的(请参阅 异步调度);并且它们可以在 CPU、GPU 或 TPU 上执行,每个设备的性能特性都差异很大且不断发展。
这些架构差异使得在 NumPy 和 JAX 之间进行有意义的直接基准比较变得困难。
此外,这些差异导致了包之间不同的工程重点:例如,NumPy 在降低单个数组操作的每次调用调度开销方面投入了大量精力,因为在 NumPy 的计算模型中,这种开销是无法避免的。另一方面,JAX 有多种方法可以避免调度开销(例如,JIT 编译、异步调度、批处理转换等),因此减少每次调用开销并不是优先事项。
考虑到所有这些,总而言之:如果您在 CPU 上进行单个数组操作的微基准测试,由于其较低的每次操作调度开销,您通常可以预期 NumPy 的性能优于 JAX。如果您在 GPU 或 TPU 上运行代码,或者正在基准测试 CPU 上更复杂的 JIT 编译操作序列,您通常可以预期 JAX 的性能优于 NumPy。
不同类型的 JAX 值#
在转换函数的过程中,JAX 将某些函数参数替换为特殊的跟踪器值。
如果您使用 print
语句,您可能会看到这种情况
def func(x):
print(x)
return jnp.cos(x)
res = jax.jit(func)(0.)
上面的代码确实返回了正确的值 1.
,但它还为 x
的值打印了 Traced<ShapedArray(float32[])>
。通常,JAX 以透明的方式在内部处理这些跟踪器值,例如,在用于实现 jax.numpy
函数的数字 JAX 原始函数中。这就是为什么 jnp.cos
在上面的示例中起作用的原因。
更准确地说,对于 JAX 转换函数的参数,引入了跟踪器值,除了由特殊参数标识的参数,例如 jax.jit()
的 static_argnums
或 jax.pmap()
的 static_broadcasted_argnums
。通常,涉及至少一个跟踪器值的计算将产生一个跟踪器值。除了跟踪器值,还有常规 Python 值:在 JAX 转换之外计算的值,或者来自上述某些 JAX 转换的静态参数,或者仅从其他常规 Python 值计算的值。这些值是在没有 JAX 转换的情况下到处使用的值。
跟踪器值携带一个抽象值,例如,具有有关数组形状和 dtype 信息的 ShapedArray
。我们将在此处将此类跟踪器称为抽象跟踪器。一些跟踪器,例如,为自动微分转换的参数引入的跟踪器,携带实际上包括常规数组数据的 ConcreteArray
抽象值,并用于例如解析条件语句。我们将在此处将此类跟踪器称为具体跟踪器。从这些具体跟踪器计算出的跟踪器值,可能与其他常规值组合,会产生具体跟踪器。具体值是常规值或具体跟踪器。
通常情况下,从追踪器值计算得到的值本身也是追踪器值。只有极少数例外情况,即当计算完全可以使用追踪器携带的抽象值完成时,结果可以是常规值。例如,使用 ShapedArray
抽象值获取追踪器的形状。另一个例子是当显式地将具体的追踪器值转换为常规类型时,例如 int(x)
或 x.astype(float)
。还有一种情况是对于 bool(x)
,当具体化使其成为可能时,它会产生一个 Python 布尔值。这种情况尤其重要,因为它在控制流中经常出现。
以下是转换如何引入抽象或具体的追踪器:
jax.jit()
:为所有位置参数引入抽象追踪器,除了由static_argnums
指定的参数,这些参数保持为常规值。jax.pmap()
:为所有位置参数引入抽象追踪器,除了由static_broadcasted_argnums
指定的参数。jax.vmap()
,jax.make_jaxpr()
,xla_computation()
:为所有位置参数引入抽象追踪器。jax.jvp()
和jax.grad()
为所有位置参数引入具体追踪器。一个例外是当这些转换位于外部转换中,并且实际参数本身是抽象追踪器时;在这种情况下,由自动微分转换引入的追踪器也是抽象追踪器。所有高阶控制流原语(
lax.cond()
,lax.while_loop()
,lax.fori_loop()
,lax.scan()
)在处理函数时会引入抽象追踪器,无论是否正在进行 JAX 转换。
当你的代码只能在常规 Python 值上操作时,所有这些都相关,例如基于数据的条件控制流代码。
def divide(x, y):
return x / y if y >= 1. else 0.
如果我们想应用 jax.jit()
,我们必须确保指定 static_argnums=1
以确保 y
保持为常规值。这是由于布尔表达式 y >= 1.
,它需要具体值(常规值或追踪器)。如果我们显式地写 bool(y >= 1.)
,或 int(y)
,或 float(y)
,也会发生同样的情况。
有趣的是,jax.grad(divide)(3., 2.)
可以工作,因为 jax.grad()
使用具体追踪器,并使用 y
的具体值解析条件。
缓冲区捐赠#
当 JAX 执行计算时,它会在设备上使用缓冲区来存储所有输入和输出。如果您知道其中一个输入在计算后不再需要,并且它与其中一个输出的形状和元素类型匹配,则可以指定您希望将相应的输入缓冲区捐赠给用于保存输出。这将减少执行所需的内存,减少量为捐赠缓冲区的大小。
如果您有类似以下模式的内容,您可以使用缓冲区捐赠:
params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state)
您可以将其视为一种对不可变的 JAX 数组执行内存高效的功能性更新的方法。在计算的边界内,XLA 可以为您进行这种优化,但在 jit/pmap 边界,您需要向 XLA 保证,在调用捐赠函数后,您将不会使用捐赠的输入缓冲区。
您可以使用函数 jax.jit()
、jax.pjit()
和 jax.pmap()
的 donate_argnums 参数来实现此目的。此参数是位置参数列表中的索引序列(从 0 开始)。
def add(x, y):
return x + y
x = jax.device_put(np.ones((2, 3)))
y = jax.device_put(np.ones((2, 3)))
# Execute `add` with donation of the buffer for `y`. The result has
# the same shape and type as `y`, so it will share its buffer.
z = jax.jit(add, donate_argnums=(1,))(x, y)
请注意,当使用关键字参数调用函数时,此功能目前不起作用!以下代码不会捐赠任何缓冲区:
params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params=params, state=state)
如果捐赠缓冲区的参数是 pytree,则会捐赠其所有组件的缓冲区。
def add_ones(xs: List[Array]):
return [x + 1 for x in xs]
xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))]
# Execute `add_ones` with donation of all the buffers for `xs`.
# The outputs have the same shape and type as the elements of `xs`,
# so they will share those buffers.
z = jax.jit(add_ones, donate_argnums=0)(xs)
不允许捐赠后续在计算中使用的缓冲区,JAX 会报错,因为 y 的缓冲区在捐赠后已失效。
# Donate the buffer for `y`
z = jax.jit(add, donate_argnums=(1,))(x, y)
w = y + 1 # Reuses `y` whose buffer was donated above
# >> RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer
如果捐赠的缓冲区未使用,您会收到警告,例如,因为捐赠的缓冲区比可用于输出的缓冲区多。
# Execute `add` with donation of the buffers for both `x` and `y`.
# One of those buffers will be used for the result, but the other will
# not be used.
z = jax.jit(add, donate_argnums=(0, 1))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[2,3]{1,0}
如果没有输出的形状与捐赠匹配,则捐赠也可能未使用。
y = jax.device_put(np.ones((1, 3))) # `y` has different shape than the output
# Execute `add` with donation of the buffer for `y`.
z = jax.jit(add, donate_argnums=(1,))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[1,3]{1,0}
使用 where
时,梯度包含 NaN#
如果使用 where
定义一个函数以避免未定义的值,如果您不小心,可能会获得用于反向微分的 NaN
。
def my_log(x):
return jnp.where(x > 0., jnp.log(x), 0.)
my_log(0.) ==> 0. # Ok
jax.grad(my_log)(0.) ==> NaN
一个简短的解释是,在 grad
计算期间,对应于未定义的 jnp.log(x)
的伴随是 NaN
,并且它会累积到 jnp.where
的伴随。编写此类函数的正确方法是确保在部分定义的函数内部有一个 jnp.where
,以确保伴随始终是有限的。
def safe_for_grad_log(x):
return jnp.log(jnp.where(x > 0., x, 1.))
safe_for_grad_log(0.) ==> 0. # Ok
jax.grad(safe_for_grad_log)(0.) ==> 0. # Ok
除了原始的 jnp.where
之外,可能还需要内部的 jnp.where
,例如:
def my_log_or_y(x, y):
"""Return log(x) if x > 0 or y"""
return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)
其他阅读材料
为什么基于排序顺序的函数的梯度为零?#
如果您定义一个使用依赖于输入相对顺序的操作(例如 max
、greater
、argsort
等)来处理输入的函数,您可能会惊讶地发现梯度处处为零。这是一个示例,其中我们将 f(x) 定义为一个阶跃函数,当 x 为负数时返回 0,当 x 为正数时返回 1:
import jax
import numpy as np
import jax.numpy as jnp
def f(x):
return (x > 0).astype(float)
df = jax.vmap(jax.grad(f))
x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0])
print(f"f(x) = {f(x)}")
# f(x) = [0. 0. 0. 1. 1.]
print(f"df(x) = {df(x)}")
# df(x) = [0. 0. 0. 0. 0.]
梯度处处为零这一事实起初可能会令人困惑:毕竟,输出确实会响应输入而变化,那么梯度怎么可能是零呢?但是,在这种情况下,零结果被证明是正确的。
为什么会这样?请记住,微分测量的是给定 x
的无穷小变化时 f
的变化。对于 x=1.0
,f
返回 1.0
。如果我们扰动 x
使其稍微大或小一点,这不会改变输出,因此根据定义,grad(f)(1.0)
应该为零。对于所有大于零的 f
值,此逻辑都成立:无穷小地扰动输入不会更改输出,因此梯度为零。类似地,对于所有小于零的 x
值,输出为零。扰动 x
不会改变此输出,因此梯度为零。这给我们留下了 x=0
的棘手情况。当然,如果您向上扰动 x
,它将改变输出,但这有问题:x
的无穷小变化会导致函数值发生有限的变化,这意味着梯度是未定义的。幸运的是,在这种情况下,我们还有另一种测量梯度的方法:我们向下扰动函数,在这种情况下,输出不会改变,因此梯度为零。JAX 和其他自动微分系统倾向于以这种方式处理不连续性:如果正梯度和负梯度不一致,但一个已定义,另一个未定义,我们使用已定义的那个。在此梯度的定义下,数学上和数值上,此函数的梯度处处为零。
问题源于我们的函数在 x = 0
处存在不连续性。我们的 f
本质上是一个 Heaviside 阶跃函数,我们可以使用 Sigmoid 函数 作为平滑的替代品。当 x 远离零时,sigmoid 函数近似等于 heaviside 函数,但它用平滑的可微曲线取代了 x = 0
处的不连续性。由于使用了 jax.nn.sigmoid()
,我们得到了类似的计算,并且具有定义良好的梯度。
def g(x):
return jax.nn.sigmoid(x)
dg = jax.vmap(jax.grad(g))
x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
with np.printoptions(suppress=True, precision=2):
print(f"g(x) = {g(x)}")
# g(x) = [0. 0.27 0.5 0.73 1. ]
print(f"dg(x) = {dg(x)}")
# dg(x) = [0. 0.2 0.25 0.2 0. ]
jax.nn
子模块还具有其他常用基于秩的函数的平滑版本,例如 jax.nn.softmax()
可以替代 jax.numpy.argmax()
的使用,jax.nn.soft_sign()
可以替代 jax.numpy.sign()
的使用,jax.nn.softplus()
或 jax.nn.squareplus()
可以替代 jax.nn.relu()
的使用,等等。
如何将 JAX Tracer 转换为 NumPy 数组?#
在运行时检查转换后的 JAX 函数时,您会发现数组值被 Tracer
对象替换。
@jax.jit
def f(x):
print(type(x))
return x
f(jnp.arange(5))
这将打印以下内容
<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
一个常见的问题是如何将这种 tracer 转换回普通的 NumPy 数组。简而言之,不可能将 Tracer 转换为 NumPy 数组,因为 tracer 是具有给定形状和 dtype 的每个可能值的抽象表示,而 NumPy 数组是该抽象类的具体成员。有关 tracer 如何在 JAX 转换的上下文中工作的更多讨论,请参阅 JIT 机制。
将 Tracer 转换回数组的问题通常出现在另一个目标的上下文中,该目标与在运行时访问计算中的中间值有关。例如
如果您希望在运行时打印跟踪值以进行调试,您可以考虑使用
jax.debug.print()
。如果您希望在转换后的 JAX 函数中调用非 JAX 代码,您可以考虑使用
jax.pure_callback()
,其示例可在 纯回调示例 中找到。如果您希望在运行时输入或输出数组缓冲区(例如,从文件加载数据,或将数组内容记录到磁盘),您可以考虑使用
jax.experimental.io_callback()
,其示例可在 IO 回调示例 中找到。
有关运行时回调及其使用示例的更多信息,请参阅 JAX 中的外部回调。
为什么某些 CUDA 库加载/初始化失败?#
在解析动态库时,JAX 使用常用的 动态链接器搜索模式。JAX 将 RPATH
设置为指向 pip 安装的 NVIDIA CUDA 包的 JAX 相对位置,如果已安装则优先使用它们。如果 ld.so
无法在其常用搜索路径中找到您的 CUDA 运行时库,那么您必须在 LD_LIBRARY_PATH
中显式包含这些库的路径。确保您的 CUDA 文件可被发现的最简单方法是简单地安装 nvidia-*-cu12
pip 包,这些包包含在标准的 jax[cuda_12]
安装选项中。
有时,即使您已确保运行时库可被发现,在加载或初始化它们时可能仍然存在一些问题。此类问题的常见原因是 CUDA 库在运行时初始化时内存不足。有时会发生这种情况,因为 JAX 会预先分配过大的当前可用设备内存块以加快执行速度,有时会导致剩余内存不足以进行运行时 CUDA 库初始化。
当运行多个 JAX 实例,与执行自身预分配的 TensorFlow 并行运行 JAX,或者在 GPU 被其他进程大量利用的系统上运行 JAX 时,这种情况尤其可能发生。如有疑问,请尝试使用减少的预分配再次运行程序,方法是将 XLA_PYTHON_CLIENT_MEM_FRACTION
从默认值 .75
降低,或设置 XLA_PYTHON_CLIENT_PREALLOCATE=false
。有关更多详细信息,请参阅 JAX GPU 内存分配页面。