checkify
转换#
摘要: Checkify 允许您将可 jit
的运行时错误检查(例如,越界索引)添加到您的 JAX 代码中。将 checkify.checkify
转换与类似断言的 checkify.check
函数一起使用,以向 JAX 代码添加运行时检查
from jax.experimental import checkify
import jax
import jax.numpy as jnp
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative, got {i}", i=i)
y = x[i]
z = jnp.sin(y)
return z
jittable_f = checkify.checkify(f)
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -2)
print(err.get())
# >> index needs to be non-negative, got -2! (check failed at <...>:6 (f))
您还可以使用 checkify 自动添加常见检查
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
err, z = checked_f(jnp.array([5, 1]), 0)
err.throw() # if no error occurred, throw does nothing!
函数式检查#
类似断言的检查 API 本身不是函数式的:它会像断言一样,作为副作用引发 Python 异常。因此,它不能与 jit
、 pmap
、 pjit
或 scan
一起分阶段执行。
jax.jit(f)(jnp.ones((5,)), -1) # checkify transformation not used
# ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized.
但是,checkify 转换将这些影响函数化(或解除)。经过 checkify 转换的函数会返回一个错误值作为新的输出,并保持函数式纯洁。这种函数化意味着,我们可以随意地将 checkify 转换后的函数与分阶段/转换进行组合。
err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
"""
ValueError:
.. at mapped index 0: index needs to be non-negative! (check failed at :6 (f))
.. at mapped index 2: out-of-bounds indexing at <..>:7 (f)
"""
为什么 JAX 需要 checkify?#
在某些 JAX 转换下,可以使用普通的 Python 断言来表达运行时错误检查,例如仅使用 jax.grad
和 jax.numpy
时。
def f(x):
assert x > 0., "must be positive!"
return jnp.log(x)
jax.grad(f)(0.)
# ValueError: "must be positive!"
但是,普通的断言在 jit
、 pmap
、 pjit
或 scan
中不起作用。在这些情况下,数值计算是分阶段执行的,而不是在 Python 执行期间急切地求值,因此数值不可用。
jax.jit(f)(0.)
# ConcretizationTypeError: "Abstract tracer value encountered ..."
JAX 转换的语义依赖于函数式纯洁,尤其是在组合多个转换时,那么我们如何在不破坏这一切的情况下提供错误机制呢?除了需要新的 API 之外,情况还更复杂:XLA HLO 不支持断言或抛出错误,因此,即使我们有一个 JAX API 能够分阶段执行断言,我们又如何将这些断言降级到 XLA 呢?
你可以想象手动向函数添加运行时检查并输出表示错误的值。
def f_checked(x):
error = x <= 0.
result = jnp.log(x)
return error, result
err, y = jax.jit(f_checked)(0.)
if err:
raise ValueError("must be positive!")
# ValueError: "must be positive!"
错误是函数计算的常规值,并且在 f_checked
之外引发错误。f_checked
是函数式纯洁的,因此我们通过构建就知道它已经可以与 jit
、pmap、pjit、scan 和所有 JAX 转换一起使用。唯一的问题是这种管道可能很麻烦!
checkify
为你执行此重写:包括通过函数传递错误值,将检查重写为布尔运算并将结果与跟踪的错误值合并,并将最终错误值作为 checkify 后的函数的输出返回。
def f(x):
checkify.check(x > 0., "{} must be positive!", x) # convenient but effectful API
return jnp.log(x)
f_checked = checkify(f)
err, x = jax.jit(f_checked)(-1.)
err.throw()
# ValueError: -1. must be positive! (check failed at <...>:2 (f))
我们称之为函数化或解除调用 check 引入的影响。(在上面的“手动”示例中,错误值只是一个布尔值。checkify 的错误值在概念上相似,但也跟踪错误消息并公开 throw 和 get 方法;请参阅 jax.experimental.checkify
)。checkify.check
还允许你通过将运行时值作为错误消息的格式参数提供,来将运行时值添加到错误消息中。
你现在可以手动使用运行时检查来检测你的代码,但 checkify
也可以自动添加对常见错误的检查!考虑以下错误情况:
jnp.arange(3)[5] # out of bounds
jnp.sin(jnp.inf) # NaN generated
jnp.ones((5,)) / jnp.arange(5) # division by zero
默认情况下,checkify
仅解除 checkify.check
的影响,而不会执行任何操作来捕获上述错误。但是,如果你要求它这样做,checkify
也会自动使用检查来检测你的代码。
def f(x, i):
y = x[i] # i could be out of bounds.
z = jnp.sin(y) # z could become NaN
return z
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
用于选择启用哪些自动检查的 API 基于 Set。有关详细信息,请参阅 jax.experimental.checkify
。
checkify
在 JAX 转换下。#
如上面的示例所示,checkify 后的函数可以愉快地进行 jit。以下是一些关于 checkify
与其他 JAX 转换结合使用的示例。请注意,checkify 后的函数是函数式纯洁的,应该可以轻松地与所有 JAX 转换组合使用!
jit
#
你可以安全地将 jax.jit
添加到 checkify 后的函数,或者对 jit 后的函数进行 checkify
,两者都可以工作。
def f(x, i):
return x[i]
checkify_of_jit = checkify.checkify(jax.jit(f))
jit_of_checkify = jax.jit(checkify.checkify(f))
err, _ = checkify_of_jit(jnp.ones((5,)), 100)
err.get()
# out-of-bounds indexing at <..>:2 (f)
err, _ = jit_of_checkify(jnp.ones((5,)), 100)
# out-of-bounds indexing at <..>:2 (f)
vmap
/pmap
#
你可以 vmap
和 pmap
checkify 后的函数(或对映射后的函数进行 checkify
)。映射 checkify 后的函数会给你一个映射后的错误,该错误可能包含映射维度的每个元素的不同错误。
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
return x[i]
checked_f = checkify.checkify(f, errors=checkify.all_checks)
errs, out = jax.vmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
errs.throw()
"""
ValueError:
at mapped index 0: index needs to be non-negative! (check failed at <...>:2 (f))
at mapped index 2: out-of-bounds indexing at <...>:3 (f)
"""
但是,对 vmap 进行 checkify 将产生一个单一的(未映射的)错误!
@jax.vmap
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
return x[i]
checked_f = checkify.checkify(f, errors=checkify.all_checks)
err, out = checked_f(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:2 (f))
pjit
#
checkify 后的函数的 pjit
可以正常工作,你只需要为错误值输出指定一个额外的 out_axis_resources
的 None
。
def f(x):
return x / x
f = checkify.checkify(f, errors=checkify.float_checks)
f = pjit(
f,
in_shardings=PartitionSpec('x', None),
out_shardings=(None, PartitionSpec('x', None)))
with jax.sharding.Mesh(mesh.devices, mesh.axis_names):
err, data = f(input_data)
err.throw()
# ValueError: divided by zero at <...>:4 (f)
grad
#
如果对 grad 进行 checkify,你的梯度计算也将被检测。
def f(x):
return x / (1 + jnp.sqrt(x))
grad_f = jax.grad(f)
err, _ = checkify.checkify(grad_f, errors=checkify.nan_checks)(0.)
print(err.get())
>> nan generated by primitive mul at <...>:3 (f)
请注意,f
中没有乘法,但在其梯度计算中存在乘法(这是生成 NaN 的地方!)。因此,使用 checkify-of-grad 为前向和后向传递操作添加自动检查。
checkify.check
将仅应用于你的函数的原始值。如果你想在梯度值上使用 check
,请使用 custom_vjp
。
@jax.custom_vjp
def assert_gradient_negative(x):
return x
def fwd(x):
return assert_gradient_negative(x), None
def bwd(_, grad):
checkify.check(grad < 0, "gradient needs to be negative!")
return (grad,)
assert_gradient_negative.defvjp(fwd, bwd)
jax.grad(assert_gradient_negative)(-1.)
# ValueError: gradient needs to be negative!
jax.experimental.checkify
的优点和局限性#
优点#
你可以在任何地方使用它(错误是“只是值”,并且在转换(如其他值)下行为直观)
自动检测:你无需对代码进行本地修改。相反,
checkify
可以检测所有代码!
局限性#
添加大量运行时检查可能会很昂贵(例如,对每个原语添加 NaN 检查会向你的计算中添加大量操作)
需要将错误值线程化输出函数并手动抛出错误。如果未显式抛出错误,你可能会错过错误!
抛出错误值会将该错误值在主机上实现,这意味着它是一个阻塞操作,它会破坏 JAX 的异步运行超前。