使用 JIT 的控制流和逻辑运算符#
当以 eager 方式执行(在 jit
之外)时,JAX 代码与 Python 控制流和逻辑运算符的工作方式与 NumPy 代码相同。将控制流和逻辑运算符与 jit
一起使用会更复杂。
简而言之,Python 控制流和逻辑运算符在 JIT 编译时进行评估,因此编译后的函数表示通过控制流图的单条路径(逻辑运算符通过短路影响路径)。如果路径取决于输入的值,则(默认情况下)无法 JIT 编译该函数。该路径可能取决于输入的形状或 dtype,并且每次在具有新形状或 dtype 的输入上调用该函数时,都会重新编译该函数。
from jax import grad, jit
import jax.numpy as jnp
例如,这可以工作
@jit
def f(x):
for i in range(3):
x = 2 * x
return x
print(f(3))
24
这也可以
@jit
def g(x):
y = 0.
for i in range(x.shape[0]):
y = y + x[i]
return y
print(g(jnp.array([1., 2., 3.])))
6.0
但是这不行,至少默认情况下不行
@jit
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
# This will fail!
f(2)
---------------------------------------------------------------------------
TracerBoolConversionError Traceback (most recent call last)
Cell In[4], line 9
6 return -4 * x
8 # This will fail!
----> 9 f(2)
[... skipping hidden 13 frame]
Cell In[4], line 3, in f(x)
1 @jit
2 def f(x):
----> 3 if x < 3:
4 return 3. * x ** 2
5 else:
[... skipping hidden 1 frame]
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py:1498, in concretization_function_error.<locals>.error(self, arg)
1497 def error(self, arg):
-> 1498 raise TracerBoolConversionError(arg)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_827/3402096563.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
这个也不行
@jit
def g(x):
return (x > 0) and (x < 3)
# This will fail!
g(2)
---------------------------------------------------------------------------
TracerBoolConversionError Traceback (most recent call last)
Cell In[5], line 6
3 return (x > 0) and (x < 3)
5 # This will fail!
----> 6 g(2)
[... skipping hidden 13 frame]
Cell In[5], line 3, in g(x)
1 @jit
2 def g(x):
----> 3 return (x > 0) and (x < 3)
[... skipping hidden 1 frame]
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py:1498, in concretization_function_error.<locals>.error(self, arg)
1497 def error(self, arg):
-> 1498 raise TracerBoolConversionError(arg)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g at /tmp/ipykernel_827/543860509.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
怎么回事!?
当我们对函数进行 jit
编译时,我们通常希望编译一个适用于多种不同参数值的函数版本,以便我们可以缓存和重用编译后的代码。这样我们就无需在每次函数评估时都重新编译。
例如,如果我们对数组 jnp.array([1., 2., 3.], jnp.float32)
评估一个 @jit
函数,我们可能希望编译可以重用的代码,以便在 jnp.array([4., 5., 6.], jnp.float32)
上评估该函数,从而节省编译时间。
为了获得适用于多种不同参数值的 Python 代码视图,JAX 使用 ShapedArray
抽象作为输入对其进行跟踪,其中每个抽象值代表具有固定形状和 dtype 的所有数组值的集合。例如,如果我们使用抽象值 ShapedArray((3,), jnp.float32)
进行跟踪,我们将获得一个可以重用于相应数组集合中任何具体值的函数视图。这意味着我们可以节省编译时间。
但这里存在一个权衡:如果我们在一个未提交给特定具体值的 ShapedArray((), jnp.float32)
上跟踪一个 Python 函数,当我们遇到像 if x < 3
这样的行时,表达式 x < 3
会计算出一个抽象的 ShapedArray((), jnp.bool_)
,它代表集合 {True, False}
。当 Python 尝试将其强制转换为具体的 True
或 False
时,我们会收到一个错误:我们不知道要采取哪个分支,并且无法继续跟踪!权衡之处在于,使用更高层次的抽象,我们可以获得 Python 代码的更通用视图(从而节省重新编译的时间),但我们需要对 Python 代码施加更多约束才能完成跟踪。
好消息是你可以自己控制这种权衡。通过让 jit
在更精细的抽象值上进行跟踪,你可以放宽可追溯性约束。例如,使用 jit
的 static_argnames
(或 static_argnums
)参数,我们可以指定在某些参数的具体值上进行跟踪。这是之前的示例函数
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
f = jit(f, static_argnames='x')
print(f(2.))
12.0
这是另一个例子,这次涉及一个循环
def f(x, n):
y = 0.
for i in range(n):
y = y + x[i]
return y
f = jit(f, static_argnames='n')
f(jnp.array([2., 3., 4.]), 2)
Array(5., dtype=float32)
实际上,循环被静态展开了。JAX 还可以在更高级别的抽象(如 Unshaped
)上进行跟踪,但目前这不是任何转换的默认设置
️⚠️ **具有依赖于参数值的形状的函数**
这些控制流问题也以更微妙的方式出现:我们想要 **jit** 的数值函数不能根据参数值来专门化内部数组的形状(根据参数 **形状** 进行专门化是可以的)。作为一个简单的示例,让我们创建一个输出恰好取决于输入变量 length
的函数。
def example_fun(length, val):
return jnp.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))
[4. 4. 4. 4. 4.]
bad_example_jit = jit(example_fun)
# this will fail:
bad_example_jit(10, 4)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[9], line 3
1 bad_example_jit = jit(example_fun)
2 # this will fail:
----> 3 bad_example_jit(10, 4)
[... skipping hidden 13 frame]
Cell In[8], line 2, in example_fun(length, val)
1 def example_fun(length, val):
----> 2 return jnp.ones((length,)) * val
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:6174, in ones(shape, dtype, device)
6172 raise TypeError("expected sequence object with len >= 0 or a single integer")
6173 if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m)
-> 6174 shape = canonicalize_shape(shape)
6175 dtypes.check_user_dtype_supported(dtype, "ones")
6176 return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:102, in canonicalize_shape(shape, context)
100 return core.canonicalize_shape((shape,), context)
101 else:
--> 102 return core.canonicalize_shape(shape, context)
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py:1643, in canonicalize_shape(shape, context)
1641 except TypeError:
1642 pass
-> 1643 raise _invalid_shape_error(shape, context)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function example_fun at /tmp/ipykernel_827/1210496444.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument length.
# static_argnames tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnames='length')
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]
如果示例中的 length
很少更改,则 static_argnames
会很方便,但如果它经常更改,那将是灾难性的!
最后,如果你的函数具有全局副作用,JAX 的跟踪器可能会导致奇怪的事情发生。一个常见的陷阱是尝试在 **jit** 函数内部打印数组
@jit
def f(x):
print(x)
y = 2 * x
print(y)
return y
f(2)
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>
Array(4, dtype=int32, weak_type=True)
结构化控制流原语#
JAX 中还有更多控制流选项。假设你想避免重新编译,但仍然想使用可跟踪的控制流,并避免展开大型循环。那么你可以使用以下 4 个结构化控制流原语
lax.cond
可微分lax.while_loop
前向模式可微分lax.fori_loop
一般情况下是 前向模式可微分;如果端点是静态的,则是 前向和反向模式可微分。lax.scan
可微分
cond
#
python 等效
def cond(pred, true_fun, false_fun, operand):
if pred:
return true_fun(operand)
else:
return false_fun(operand)
from jax import lax
operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32)
Array([-1.], dtype=float32)
jax.lax
提供了另外两个函数,允许在动态谓词上进行分支
lax.select
就像lax.cond
的批量版本,选择以预先计算的数组而不是函数的形式表示。lax.switch
类似于lax.cond
,但允许在任意数量的可调用选择之间切换。
此外,jax.numpy
为这些函数提供了几个 numpy 风格的接口
jnp.where
与三个参数是lax.select
的 numpy 风格包装器。jnp.piecewise
是lax.switch
的 numpy 风格包装器,但它基于一系列布尔条件而不是单个标量索引进行切换。jnp.select
具有类似于jnp.piecewise
的 API,但选择以预先计算的数组而不是函数的形式给出。它通过多次调用lax.select
来实现。
while_loop
#
python 等效
def while_loop(cond_fun, body_fun, init_val):
val = init_val
while cond_fun(val):
val = body_fun(val)
return val
init_val = 0
cond_fun = lambda x: x < 10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
Array(10, dtype=int32, weak_type=True)
fori_loop
#
python 等效
def fori_loop(start, stop, body_fun, init_val):
val = init_val
for i in range(start, stop):
val = body_fun(i, val)
return val
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
Array(45, dtype=int32, weak_type=True)
摘要#
\(\ast\) = 与参数值无关的循环条件 - 展开循环
逻辑运算符#
jax.numpy
提供了 logical_and
、logical_or
和 logical_not
,它们在数组上逐元素操作,并且可以在 jit
下进行评估而无需重新编译。与它们的 Numpy 对应项一样,二元运算符不会短路。按位运算符(&
、|
、~
)也可以与 jit
一起使用。
例如,考虑一个检查其输入是否为正偶整数的函数。当输入为标量时,纯 Python 和 JAX 版本给出相同的答案。
def python_check_positive_even(x):
is_even = x % 2 == 0
# `and` short-circults, so when `is_even` is `False`, `x > 0` is not evaluated.
return is_even and (x > 0)
@jit
def jax_check_positive_even(x):
is_even = x % 2 == 0
# `logical_and` does not short circuit, so `x > 0` is always evaluated.
return jnp.logical_and(is_even, x > 0)
print(python_check_positive_even(24))
print(jax_check_positive_even(24))
True
True
当 JAX 版本使用 logical_and
应用于数组时,它会返回逐元素的值。
x = jnp.array([-1, 2, 5])
print(jax_check_positive_even(x))
[False True False]
Python 逻辑运算符在应用于多个元素的 JAX 数组时会报错,即使没有 jit
也是如此。这复制了 NumPy 的行为。
print(python_check_positive_even(x))
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[17], line 1
----> 1 print(python_check_positive_even(x))
Cell In[15], line 4, in python_check_positive_even(x)
2 is_even = x % 2 == 0
3 # `and` short-circults, so when `is_even` is `False`, `x > 0` is not evaluated.
----> 4 return is_even and (x > 0)
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/array.py:292, in ArrayImpl.__bool__(self)
291 def __bool__(self):
--> 292 core.check_bool_conversion(self)
293 return bool(self._value)
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py:655, in check_bool_conversion(arr)
652 raise ValueError("The truth value of an empty array is ambiguous. Use"
653 " `array.size > 0` to check that an array is not empty.")
654 if arr.size > 1:
--> 655 raise ValueError("The truth value of an array with more than one element"
656 " is ambiguous. Use a.any() or a.all()")
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Python 控制流 + 自动微分#
请记住,以上关于控制流和逻辑运算符的约束仅与 jit
相关。如果你只想将 grad
应用于你的 python 函数,而没有 jit
,你可以像使用 Autograd(或 Pytorch 或 TF Eager)一样,毫无问题地使用常规的 Python 控制流结构。
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
print(grad(f)(2.)) # ok!
print(grad(f)(4.)) # ok!
12.0
-4.0