🔪 JAX - 最锋利的刀刃 🔪#
当您在意大利乡村散步时,当地人会毫不犹豫地告诉您,JAX 拥有 “una anima di pura programmazione funzionale”。
JAX 是一种用于表达和组合数值程序变换的语言。JAX 还可以编译针对 CPU 或加速器(GPU/TPU)的数值程序。JAX 在许多数值和科学程序中表现出色,但只有在使用我们将在下面描述的某些约束条件编写时才适用。
import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp
🔪 纯函数#
JAX 的变换和编译旨在仅在功能上纯粹的 Python 函数上工作:所有输入数据都通过函数参数传递,所有结果都通过函数结果输出。如果使用相同的输入调用纯函数,它将始终返回相同的结果。
以下是一些 JAX 行为与 Python 解释器不同的非功能纯函数示例。请注意,这些行为并非 JAX 系统保证的;使用 JAX 的正确方法是仅在功能上纯粹的 Python 函数上使用它。
def impure_print_side_effect(x):
print("Executing function") # This is a side-effect
return x
# The side-effects appear during the first run
print ("First call: ", jit(impure_print_side_effect)(4.))
# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(impure_print_side_effect)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))
Executing function
First call: 4.0
Second call: 5.0
Executing function
Third call, different type: [5.]
g = 0.
def impure_uses_globals(x):
return x + g
# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10. # Update the global
# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))
First call: 4.0
Second call: 5.0
Third call, different type: [14.]
g = 0.
def impure_saves_global(x):
global g
g = x
return x
# JAX runs once the transformed function with special Traced values for arguments
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g) # Saved global has an internal JAX value
First call: 4.0
Saved global: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
即使 Python 函数在内部实际使用有状态对象,只要它不读取或写入外部状态,它就可以是功能上纯粹的
def pure_uses_internal_state(x):
state = dict(even=0, odd=0)
for i in range(10):
state['even' if i % 2 == 0 else 'odd'] += x
return state['even'] + state['odd']
print(jit(pure_uses_internal_state)(5.))
50.0
不建议在您想要 jit
的任何 JAX 函数或任何控制流原语中使用迭代器。原因是迭代器是一个 Python 对象,它引入了状态来检索下一个元素。因此,它与 JAX 函数式编程模型不兼容。在下面的代码中,有一些尝试将迭代器与 JAX 一起使用的错误示例。其中大多数会返回错误,但有些会给出意外的结果。
import jax.numpy as jnp
from jax import make_jaxpr
# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0
# lax.scan
def func11(arr, extra):
ones = jnp.ones(arr.shape)
def body(carry, aelems):
ae1, ae2 = aelems
return (carry + ae1 * ae2 + extra, carry)
return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error
# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error
45
0
🔪 就地更新#
在 NumPy 中,您习惯于这样做
numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)
# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)
original array:
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
updated array:
[[0. 0. 0.]
[1. 1. 1.]
[0. 0. 0.]]
但是,如果我们尝试就地更新 JAX 设备数组,我们会得到一个错误!(☉_☉)
%xmode Minimal
Exception reporting mode: Minimal
jax_array = jnp.zeros((3,3), dtype=jnp.float32)
# In place update of JAX's array will yield an error!
jax_array[1, :] = 1.0
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.ac.cn/en/latest/_autosummary/jax.numpy.ndarray.at.html
允许就地修改变量会使程序分析和转换变得困难。JAX 要求程序是纯函数。
相反,JAX 提供了使用 .at
属性 的函数式数组更新。
️⚠️ 在 jit
的代码和 lax.while_loop
或 lax.fori_loop
中,切片的大小不能是参数值的函数,而只能是参数形状的函数 - 切片起始索引没有这样的限制。有关此限制的更多信息,请参阅下面的控制流部分。
数组更新:x.at[idx].set(y)
#
例如,上面的更新可以写成
updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array)
updated array:
[[0. 0. 0.]
[1. 1. 1.]
[0. 0. 0.]]
与 NumPy 版本不同,JAX 的数组更新函数在非就地运行。也就是说,更新后的数组将作为新数组返回,原始数组不会被更新修改。
print("original array unchanged:\n", jax_array)
original array unchanged:
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
但是,在jit编译的代码中,如果 x.at[idx].set(y)
的输入值x
未被重用,编译器将优化数组更新为就地进行。
使用其他操作进行数组更新#
索引数组更新不限于简单地覆盖值。例如,我们可以执行如下索引加法
print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)
new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)
original array:
[[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]]
有关索引数组更新的更多详细信息,请参阅 .at 属性的文档
。
🔪 索引越界#
在 NumPy 中,您习惯于在索引数组超出其边界时抛出错误,如下所示
np.arange(10)[11]
IndexError: index 11 is out of bounds for axis 0 with size 10
但是,从在加速器上运行的代码中抛出错误可能很困难或不可能。因此,JAX 必须为索引越界选择一些非错误行为(类似于无效浮点运算如何导致 NaN
)。当索引操作是数组索引更新(例如 index_add
或 scatter
之类的原语)时,将跳过在越界索引处的更新;当操作是数组索引检索(例如 NumPy 索引或 gather
之类的原语)时,索引将被钳制到数组的边界,因为必须返回一些内容。例如,最后一个数组值将从此索引操作返回
jnp.arange(10)[11]
Array(9, dtype=int32)
如果您想要对越界索引的行为进行更细粒度的控制,可以使用 ndarray.at
的可选参数;例如
jnp.arange(10.0).at[11].get()
Array(9., dtype=float32)
jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan)
Array(nan, dtype=float32)
请注意,由于上面描述的两种行为不是彼此的逆运算,因此反向模式自动微分(将索引更新转换为索引检索,反之亦然)不会保留越界索引的语义。因此,将 JAX 中的越界索引视为 未定义行为 的情况可能是个好主意。
还要注意,由于上面描述的两种行为不是彼此的逆运算,因此反向模式自动微分(将索引更新转换为索引检索,反之亦然)不会保留越界索引的语义。因此,将 JAX 中的越界索引视为 未定义行为 的情况可能是个好主意。
🔪 非数组输入:NumPy 与 JAX#
NumPy 通常很乐意接受 Python 列表或元组作为其 API 函数的输入
np.sum([1, 2, 3])
np.int64(6)
JAX 与此不同,通常会返回一个有用的错误
jnp.sum([1, 2, 3])
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.
这是一个深思熟虑的设计选择,因为将列表或元组传递给跟踪函数会导致静默的性能下降,否则可能难以检测到。
例如,考虑以下允许列表输入的 jnp.sum
的宽松版本
def permissive_sum(x):
return jnp.sum(jnp.array(x))
x = list(range(10))
permissive_sum(x)
Array(45, dtype=int32)
输出是我们期望的,但这隐藏了潜在的性能问题。在 JAX 的跟踪和 JIT 编译模型中,Python 列表或元组中的每个元素都被视为一个单独的 JAX 变量,并被单独处理并推送到设备。这可以在上面 permissive_sum
函数的 jaxpr 中看到
make_jaxpr(permissive_sum)(x)
{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:i32[] f:i32[] g:i32[] h:i32[] i:i32[]
j:i32[]. let
k:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
l:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
m:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c
n:i32[] = convert_element_type[new_dtype=int32 weak_type=False] d
o:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
p:i32[] = convert_element_type[new_dtype=int32 weak_type=False] f
q:i32[] = convert_element_type[new_dtype=int32 weak_type=False] g
r:i32[] = convert_element_type[new_dtype=int32 weak_type=False] h
s:i32[] = convert_element_type[new_dtype=int32 weak_type=False] i
t:i32[] = convert_element_type[new_dtype=int32 weak_type=False] j
u:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] k
v:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] l
w:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] m
x:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] n
y:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] o
z:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] p
ba:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] q
bb:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] r
bc:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] s
bd:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] t
be:i32[10] = concatenate[dimension=0] u v w x y z ba bb bc bd
bf:i32[] = reduce_sum[axes=(0,)] be
in (bf,) }
列表的每个条目都作为单独的输入处理,导致跟踪和编译开销随列表大小线性增长。为了防止出现这样的意外情况,JAX 避免了将列表和元组隐式转换为数组。
如果您想将元组或列表传递给 JAX 函数,您可以先将其显式转换为数组
jnp.sum(jnp.array(x))
Array(45, dtype=int32)
🔪 随机数#
如果所有由于
rand()
不好而导致结果有疑问的科学论文从图书馆书架上消失,那么每个书架上都会有一个与你的拳头一样大的空缺。 - 数值食谱
RNG 和状态#
您习惯于从 numpy 和其他库中使用有状态的伪随机数生成器 (PRNG),它们在幕后隐藏了许多细节,为您提供了一个现成的伪随机数源
print(np.random.random())
print(np.random.random())
print(np.random.random())
0.8726807392173378
0.620972743814503
0.07376745132319462
在幕后,numpy 使用 Mersenne Twister PRNG 为其伪随机函数提供动力。PRNG 的周期为 \(2^{19937}-1\),并且在任何时候都可以用624 个 32 位无符号整数和一个位置来描述,该位置指示使用了多少“熵”。
np.random.seed(0)
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
# 2481403966, 4042607538, 337614300, ... 614 more numbers...,
# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)
每当需要随机数时,这个伪随机状态向量都会在幕后自动更新,在 Mersenne Twister 状态向量中“消耗” 2 个 uint32。
_ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)
# Let's exhaust the entropy in this PRNG statevector
for i in range(311):
_ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)
# Next call iterates the RNG state for a new batch of fake "entropy".
_ = np.random.uniform()
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([1499117434, 2949980591, 2242547484,
# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)
魔术 PRNG 状态的问题在于,很难推理它如何在不同的线程、进程和设备之间使用和更新,并且当熵生成和消耗的细节对最终用户隐藏时,它非常容易弄错。
Mersenne Twister PRNG 也已知存在一些问题,它具有 2.5kB 的大型状态大小,这会导致有问题的初始化问题。它无法通过现代 BigCrush 测试,并且通常很慢。
JAX PRNG#
JAX 而是实现了一个显式 PRNG,其中熵生成和消耗通过显式传递和迭代 PRNG 状态来处理。JAX 使用现代的Threefry 基于计数器的 PRNG,它是可拆分的。也就是说,它的设计允许我们分叉 PRNG 状态到新的 PRNG 中,用于并行随机生成。
随机状态由一个特殊的数组元素描述,我们称之为键
key = random.key(0)
key
Array((), dtype=key<fry>) overlaying:
[0 0]
JAX 的随机函数从 PRNG 状态生成伪随机数,但不会更改状态!
重复使用相同的状态会导致悲伤和单调,剥夺了最终用户获得生命赐予的混乱
print(random.normal(key, shape=(1,)))
print(key)
# No no no!
print(random.normal(key, shape=(1,)))
print(key)
[-0.20584226]
Array((), dtype=key<fry>) overlaying:
[0 0]
[-0.20584226]
Array((), dtype=key<fry>) overlaying:
[0 0]
相反,每当我们需要一个新的伪随机数时,我们都会拆分 PRNG 以获得可用的子键
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(r" \---SPLIT --> new key ", key)
print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key Array((), dtype=key<fry>) overlaying:
[0 0]
\---SPLIT --> new key Array((), dtype=key<fry>) overlaying:
[4146024105 967050713]
\--> new subkey Array((), dtype=key<fry>) overlaying:
[2718843009 1272950319] --> normal [-1.2515389]
我们传播键并在每次需要新随机数时生成新的子键
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(r" \---SPLIT --> new key ", key)
print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key Array((), dtype=key<fry>) overlaying:
[4146024105 967050713]
\---SPLIT --> new key Array((), dtype=key<fry>) overlaying:
[2384771982 3928867769]
\--> new subkey Array((), dtype=key<fry>) overlaying:
[1278412471 2182328957] --> normal [-0.58665055]
我们可以一次生成多个子键
key, *subkeys = random.split(key, 4)
for subkey in subkeys:
print(random.normal(subkey, shape=(1,)))
[-0.37533438]
[0.98645043]
[0.14553197]
🔪 控制流#
✔ Python 控制流 + 自动微分 ✔#
如果您只想对您的 Python 函数应用 grad
,您可以毫无问题地使用常规的 Python 控制流结构,就好像您使用的是 Autograd(或 Pytorch 或 TF Eager)一样。
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
print(grad(f)(2.)) # ok!
print(grad(f)(4.)) # ok!
12.0
-4.0
Python 控制流 + JIT#
在 jit
中使用控制流更复杂,默认情况下它有更多约束。
这可以工作
@jit
def f(x):
for i in range(3):
x = 2 * x
return x
print(f(3))
24
这也行
@jit
def g(x):
y = 0.
for i in range(x.shape[0]):
y = y + x[i]
return y
print(g(jnp.array([1., 2., 3.])))
6.0
但至少默认情况下,这不是这样
@jit
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
# This will fail!
f(2)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_1330/3402096563.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.ac.cn/en/latest/errors.html#jax.errors.TracerBoolConversionError
怎么回事!?
当我们jit
-编译一个函数时,我们通常希望编译一个适用于许多不同参数值的函数版本,这样我们就可以缓存和重用编译后的代码。这样一来,我们就不必在每次函数评估时都重新编译。
例如,如果我们在数组 jnp.array([1., 2., 3.], jnp.float32)
上评估一个 @jit
函数,我们可能希望编译一个可以在 jnp.array([4., 5., 6.], jnp.float32)
上重用以评估函数的代码,以节省编译时间。
为了获得您 Python 代码的视图,该视图对于许多不同的参数值有效,JAX 会在表示一组可能输入的抽象值上对其进行跟踪。有多个不同的抽象级别,不同的转换使用不同的抽象级别。
默认情况下,jit
在 ShapedArray
抽象级别上跟踪您的代码,其中每个抽象值表示具有固定形状和 dtype 的所有数组值的集合。例如,如果我们使用抽象值 ShapedArray((3,), jnp.float32)
进行跟踪,我们将获得一个函数视图,该视图可以重复用于对应数组集中的任何具体值。这意味着我们可以节省编译时间。
但是,这里有一个权衡:如果我们在一个未提交到特定具体值的 ShapedArray((), jnp.float32)
上跟踪 Python 函数,当我们遇到类似 if x < 3
这样的行时,表达式 x < 3
会计算出一个抽象的 ShapedArray((), jnp.bool_)
,它表示集合 {True, False}
。当 Python 尝试将其强制转换为具体的 True
或 False
时,我们会得到一个错误:我们不知道该走哪个分支,无法继续跟踪!权衡是,随着抽象层次的提高,我们对 Python 代码的理解更加全面(从而节省了重新编译的时间),但我们需要对 Python 代码施加更多限制才能完成跟踪。
好消息是,你可以自己控制这种权衡。通过让 jit
在更精细的抽象值上进行跟踪,你可以放宽可跟踪性约束。例如,使用 jit
的 static_argnums
参数,我们可以指定在某些参数的具体值上进行跟踪。以下是那个示例函数的另一个版本
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
f = jit(f, static_argnums=(0,))
print(f(2.))
12.0
以下是一个示例,这次涉及循环
def f(x, n):
y = 0.
for i in range(n):
y = y + x[i]
return y
f = jit(f, static_argnums=(1,))
f(jnp.array([2., 3., 4.]), 2)
Array(5., dtype=float32)
实际上,循环被静态地展开。JAX 也可以在更高级别的抽象上进行跟踪,比如 Unshaped
,但这目前不是任何转换的默认行为。
️⚠️ 参数-值依赖形状的函数
这些控制流问题也会以更微妙的方式出现:我们想要jit 的数值函数不能根据参数值来专门化内部数组的形状(根据参数形状进行专门化是可以的)。作为一个简单的例子,让我们创建一个函数,其输出恰好依赖于输入变量 length
。
def example_fun(length, val):
return jnp.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))
[4. 4. 4. 4. 4.]
bad_example_jit = jit(example_fun)
# this will fail:
bad_example_jit(10, 4)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function example_fun at /tmp/ipykernel_1330/1210496444.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument length.
# static_argnums tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnums=(0,))
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]
static_argnums
在我们的示例中 length
很少改变的情况下非常有用,但如果它经常改变,就会造成灾难!
最后,如果你的函数具有全局副作用,JAX 的跟踪器可能会导致奇怪的事情发生。一个常见的错误是试图在jit 的函数中打印数组
@jit
def f(x):
print(x)
y = 2 * x
print(y)
return y
f(2)
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Array(4, dtype=int32, weak_type=True)
结构化控制流原语#
JAX 中还有更多控制流选项。假设你想避免重新编译,但仍然想使用可跟踪的控制流,并且避免展开大型循环。那么你可以使用以下 4 个结构化控制流原语
lax.cond
可微分lax.while_loop
前向模式可微分lax.fori_loop
通常情况下前向模式可微分;如果端点是静态的,则前向和反向模式可微分。lax.scan
可微分
cond
#
python 等价物
def cond(pred, true_fun, false_fun, operand):
if pred:
return true_fun(operand)
else:
return false_fun(operand)
from jax import lax
operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32)
Array([-1.], dtype=float32)
jax.lax
提供另外两个函数,允许根据动态谓词进行分支
lax.select
就像lax.cond
的批处理版本,选择表示为预先计算的数组,而不是函数。lax.switch
就像lax.cond
,但允许在任意数量的可调用选择之间进行切换。
此外,jax.numpy
为这些函数提供了几个 numpy 风格的接口
jnp.where
有三个参数,是lax.select
的 numpy 风格包装器。jnp.piecewise
是lax.switch
的 numpy 风格包装器,但根据布尔条件列表而不是单个标量索引进行切换。jnp.select
的 API 与jnp.piecewise
类似,但选择表示为预先计算的数组,而不是函数。它通过对lax.select
的多次调用实现。
while_loop
#
python 等价物
def while_loop(cond_fun, body_fun, init_val):
val = init_val
while cond_fun(val):
val = body_fun(val)
return val
init_val = 0
cond_fun = lambda x: x < 10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
Array(10, dtype=int32, weak_type=True)
fori_loop
#
python 等价物
def fori_loop(start, stop, body_fun, init_val):
val = init_val
for i in range(start, stop):
val = body_fun(i, val)
return val
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
Array(45, dtype=int32, weak_type=True)
总结#
\(\ast\) = 参数-值-独立循环条件 - 展开循环
🔪 动态形状#
在像 jax.jit
、jax.vmap
、jax.grad
等变换中使用的 JAX 代码要求所有输出数组和中间数组具有静态形状:也就是说,形状不能依赖于其他数组中的值。
例如,如果你要实现你自己的 jnp.nansum
版本,你可能会从类似这样的东西开始
def nansum(x):
mask = ~jnp.isnan(x) # boolean mask selecting non-nan values
x_without_nans = x[mask]
return x_without_nans.sum()
在 JIT 和其他变换之外,这按预期工作
x = jnp.array([1, 2, jnp.nan, 3, 4])
print(nansum(x))
10.0
如果你尝试将 jax.jit
或其他变换应用于此函数,它将出错
jax.jit(nansum)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5])
See https://jax.ac.cn/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
问题是 x_without_nans
的大小取决于 x
中的值,这也就是它的大小是动态的另一种说法。通常在 JAX 中,可以通过其他方法来避免对动态大小数组的需要。例如,这里可以使用 jnp.where
的三参数形式用零替换 NaN 值,从而计算出相同的结果,同时避免了动态形状
@jax.jit
def nansum_2(x):
mask = ~jnp.isnan(x) # boolean mask selecting non-nan values
return jnp.where(mask, x, 0).sum()
print(nansum_2(x))
10.0
在其他出现动态形状数组的情况下,也可以采用类似的技巧。
🔪 NaN#
调试 NaN#
如果你想跟踪你的函数或梯度中出现 NaN 的位置,你可以通过以下方式打开 NaN 检查器
设置
JAX_DEBUG_NANS=True
环境变量;在你的主文件开头添加
jax.config.update("jax_debug_nans", True)
;在你的主文件中添加
jax.config.parse_flags_with_absl()
,然后使用类似--jax_debug_nans=True
的命令行标志设置选项;
这将导致计算在生成 NaN 时立即出错。打开此选项会在 XLA 生成的每个浮点类型值上添加一个 NaN 检查。这意味着对于不在 @jit
下的每个基本操作,值都会被拉回到主机并以 ndarray 的形式进行检查。对于 @jit
下的代码,会检查每个 @jit
函数的输出,如果存在 NaN,它会以非优化逐操作模式重新运行函数,实际上一次移除一层 @jit
。
可能会出现一些棘手的情况,比如只有在 @jit
下才会出现的 NaN,但在非优化模式下不会生成。在这种情况下,你会看到一个警告消息打印出来,但你的代码将继续执行。
如果 NaN 是在梯度评估的反向传播过程中生成的,当抛出异常时,在堆栈跟踪中向上几帧,你将位于 backward_pass 函数中,它本质上是一个简单的 jaxpr 解释器,它反向遍历基本操作序列。在下面的例子中,我们使用命令行 env JAX_DEBUG_NANS=True ipython
启动了一个 ipython repl,然后运行了以下代码
In [1]: import jax.numpy as jnp
In [2]: jnp.divide(0., 0.)
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
<ipython-input-2-f2e2c413b437> in <module>()
----> 1 jnp.divide(0., 0.)
.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
343 return floor_divide(x1, x2)
344 else:
--> 345 return true_divide(x1, x2)
346
347
.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
332 x1, x2 = _promote_shapes(x1, x2)
333 return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334 lax.convert_element_type(x2, result_dtype))
335
336
.../jax/jax/lax.pyc in div(x, y)
244 def div(x, y):
245 r"""Elementwise division: :math:`x \over y`."""
--> 246 return div_p.bind(x, y)
247
248 def rem(x, y):
... stack trace ...
.../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)
103 py_val = device_buffer.to_py()
104 if np.any(np.isnan(py_val)):
--> 105 raise FloatingPointError("invalid value")
106 else:
107 return Array(device_buffer, *result_shape)
FloatingPointError: invalid value
生成的 NaN 被捕获了。通过运行 %debug
,我们可以获得事后调试器。这在 @jit
下的函数中也能工作,如下面的示例所示。
In [4]: from jax import jit
In [5]: @jit
...: def f(x, y):
...: a = x * y
...: b = (x + y) / (x - y)
...: c = a + 2
...: return a + b * c
...:
In [6]: x = jnp.array([2., 0.])
In [7]: y = jnp.array([3., 0.])
In [8]: f(x, y)
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)
... stack trace ...
<ipython-input-5-619b39acbaac> in f(x, y)
2 def f(x, y):
3 a = x * y
----> 4 b = (x + y) / (x - y)
5 c = a + 2
6 return a + b * c
.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
343 return floor_divide(x1, x2)
344 else:
--> 345 return true_divide(x1, x2)
346
347
.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
332 x1, x2 = _promote_shapes(x1, x2)
333 return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334 lax.convert_element_type(x2, result_dtype))
335
336
.../jax/jax/lax.pyc in div(x, y)
244 def div(x, y):
245 r"""Elementwise division: :math:`x \over y`."""
--> 246 return div_p.bind(x, y)
247
248 def rem(x, y):
... stack trace ...
当这段代码在 @jit
函数的输出中看到 NaN 时,它会调用非优化代码,因此我们仍然可以获得清晰的堆栈跟踪。我们可以使用 %debug
运行事后调试器来检查所有值,以找出错误。
⚠️ 如果你没有进行调试,你不应该打开 NaN 检查器,因为它会导致大量设备-主机往返和性能下降!
⚠️ NaN 检查器不适用于 pmap
。要调试 pmap
代码中的 NaN,一种尝试方法是用 vmap
替换 pmap
。
🔪 双精度 (64 位)#
目前,JAX 默认情况下强制执行单精度数字,以减轻 Numpy API 积极地将操作数提升为 double
的倾向。对于许多机器学习应用来说,这是期望的行为,但它可能会让你感到意外!
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype
/tmp/ipykernel_1330/1258726447.py:1: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
dtype('float32')
要使用双精度数字,你需要在启动时设置 jax_enable_x64
配置变量。
有几种方法可以做到这一点
你可以通过设置环境变量
JAX_ENABLE_X64=True
来启用 64 位模式。你可以在启动时手动设置
jax_enable_x64
配置标志# again, this only works on startup! import jax jax.config.update("jax_enable_x64", True)
你可以使用
absl.app.run(main)
解析命令行标志。import jax jax.config.config_with_absl()
如果你希望 JAX 帮你执行 absl 解析,即你不想使用
absl.app.run(main)
,你可以使用import jax if __name__ == '__main__': # calls jax.config.config_with_absl() *and* runs absl parsing jax.config.parse_flags_with_absl()
注意,#2-#4 适用于 JAX 的所有配置选项。
然后,我们可以确认 x64
模式已启用,例如
import jax
import jax.numpy as jnp
from jax import random
jax.config.update("jax_enable_x64", True)
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64')
注意事项#
⚠️ XLA 不支持所有后端上的 64 位卷积!
🔪 与 NumPy 的差异#
虽然 jax.numpy
尽力复制 NumPy API 的行为,但在某些情况下,它们的行为仍然存在差异。上面几节详细讨论了许多这类情况;这里列出了几个 API 不同的地方。
对于二元运算,JAX 的类型提升规则与 NumPy 使用的规则略有不同。有关详细信息,请参阅 类型提升语义。
在执行不安全的类型转换时(例如,目标数据类型不能表示输入值的转换),JAX 的行为可能取决于后端,并且通常与 NumPy 的行为不同。Numpy 允许通过
casting
参数控制这些情况下的结果(请参阅np.ndarray.astype
);JAX 不提供任何此类配置,而是直接继承 XLA:ConvertElementType 的行为。以下是不安全转换的一个示例,NumPy 和 JAX 的结果不同
>>> np.arange(254.0, 258.0).astype('uint8') array([254, 255, 0, 1], dtype=uint8) >>> jnp.arange(254.0, 258.0).astype('uint8') Array([254, 255, 255, 255], dtype=uint8)
这种不匹配通常出现在将极值从浮点型转换为整型或反之亦然时。
结束语#
如果这里没有涵盖的内容导致你痛苦不堪,请告诉我们,我们会扩展这些入门级建议!