checkify 转换#

摘要: Checkify 允许您向 JAX 代码添加可 jit 的运行时错误检查(例如,越界索引)。将 checkify.checkify 转换与类似断言的 checkify.check 函数一起使用,以向 JAX 代码添加运行时检查

from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative, got {i}", i=i)
  y = x[i]
  z = jnp.sin(y)
  return z

jittable_f = checkify.checkify(f)

err, z = jax.jit(jittable_f)(jnp.ones((5,)), -2)
print(err.get())
# >> index needs to be non-negative, got -2! (check failed at <...>:6 (f))

您还可以使用 checkify 自动添加常见检查

errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)

err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)

err, z = checked_f(jnp.array([5, 1]), 0)
err.throw()  # if no error occurred, throw does nothing!

函数化检查#

类似断言的 check API 本身不是纯函数式的:它可以像断言一样引发 Python 异常作为副作用。因此,它不能与 jitpmappjitscan 一起被分解

jax.jit(f)(jnp.ones((5,)), -1)  # checkify transformation not used
# ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized.

但是 checkify 转换会将这些影响函数化(或消除)。经过 checkify 转换的函数会返回一个错误作为新的输出,并保持纯函数式。这种函数化意味着经过 checkify 转换的函数可以随意与分解/转换组合

err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
"""
ValueError:
..  at mapped index 0: index needs to be non-negative! (check failed at :6 (f))
..  at mapped index 2: out-of-bounds indexing at <..>:7 (f)
"""

为什么 JAX 需要 checkify?#

在某些 JAX 转换下,您可以使用普通的 Python 断言来表达运行时错误检查,例如,当仅使用 jax.gradjax.numpy

def f(x):
  assert x > 0., "must be positive!"
  return jnp.log(x)

jax.grad(f)(0.)
# ValueError: "must be positive!"

但是普通的断言在 jitpmappjitscan 内部不起作用。在这些情况下,数值计算是被分解而不是在 Python 执行期间被急切地评估,因此数值不可用

jax.jit(f)(0.)
# ConcretizationTypeError: "Abstract tracer value encountered ..."

JAX 转换语义依赖于函数纯度,尤其是在组合多个转换时,那么我们如何在不破坏这一切的情况下提供错误机制呢?除了需要一个新的 API 之外,情况更加棘手:XLA HLO 不支持断言或抛出错误,所以即使我们有一个能够分解断言的 JAX API,我们又如何将这些断言降低为 XLA 呢?

您可以想象手动向您的函数添加运行时检查并管道输出表示错误的值

def f_checked(x):
  error = x <= 0.
  result = jnp.log(x)
  return error, result

err, y = jax.jit(f_checked)(0.)
if err:
  raise ValueError("must be positive!")
# ValueError: "must be positive!"

错误是函数计算的常规值,并且错误在 f_checked 之外引发。f_checked 是纯函数式的,因此我们通过构造知道它已经可以与 jit、pmap、pjit、scan 和所有 JAX 转换一起使用。唯一的问题是这种管道输送可能会很麻烦!

checkify 会为您进行此重写:这包括通过函数管道输出错误值,将检查重写为布尔运算并将结果与跟踪的错误值合并,并将最终错误值作为 checkified 函数的输出返回

def f(x):
  checkify.check(x > 0., "{} must be positive!", x)  # convenient but effectful API
  return jnp.log(x)

f_checked = checkify(f)

err, x = jax.jit(f_checked)(-1.)
err.throw()
# ValueError: -1. must be positive! (check failed at <...>:2 (f))

我们称之为函数化或消除调用 check 引入的影响。(在上面的“手动”示例中,错误值只是一个布尔值。checkify 的错误值在概念上相似,但也跟踪错误消息并公开 throw 和 get 方法;请参阅 jax.experimental.checkify)。checkify.check 还允许您通过将运行时值作为错误消息的格式参数提供,将其添加到错误消息中。

您现在可以手动使用运行时检查来检测您的代码,但是 checkify 还可以自动添加对常见错误的检查!考虑以下错误情况

jnp.arange(3)[5]                # out of bounds
jnp.sin(jnp.inf)                # NaN generated
jnp.ones((5,)) / jnp.arange(5)  # division by zero

默认情况下,checkify 仅会消除 checkify.check,而不会执行任何操作来捕获上述错误。但是,如果您要求它这样做,checkify 也会自动使用检查来检测您的代码。

def f(x, i):
  y = x[i]        # i could be out of bounds.
  z = jnp.sin(y)  # z could become NaN
  return z

errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)

用于选择启用哪些自动检查的 API 基于 Sets。有关更多详细信息,请参阅 jax.experimental.checkify

JAX 转换下的 checkify#

如上面的示例所示,checkified 函数可以安全地进行 jit 处理。以下是一些 checkify 与其他 JAX 转换的更多示例。请注意,checkified 函数是纯函数式的,并且应该可以轻松地与所有 JAX 转换组合!

jit#

您可以安全地将 jax.jit 添加到 checkified 函数,或者对 jit 处理的函数进行 checkify,两者都可以工作。

def f(x, i):
  return x[i]

checkify_of_jit = checkify.checkify(jax.jit(f))
jit_of_checkify = jax.jit(checkify.checkify(f))
err, _ =  checkify_of_jit(jnp.ones((5,)), 100)
err.get()
# out-of-bounds indexing at <..>:2 (f)
err, _ = jit_of_checkify(jnp.ones((5,)), 100)
# out-of-bounds indexing at <..>:2 (f)

vmap/pmap#

您可以 vmappmap checkified 函数(或 checkify 映射函数)。映射 checkified 函数将为您提供一个映射的错误,该错误可以包含映射维度的每个元素的不同错误。

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  return x[i]

checked_f = checkify.checkify(f, errors=checkify.all_checks)
errs, out = jax.vmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
errs.throw()
"""
ValueError:
  at mapped index 0: index needs to be non-negative! (check failed at <...>:2 (f))
  at mapped index 2: out-of-bounds indexing at <...>:3 (f)
"""

但是,checkify-of-vmap 将产生一个单一的(未映射)错误!

@jax.vmap
def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  return x[i]

checked_f = checkify.checkify(f, errors=checkify.all_checks)
err, out = checked_f(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:2 (f))

pjit#

checkified 函数的 pjit 只是工作,您只需要为错误值输出指定一个额外的 out_axis_resourcesNone

def f(x):
  return x / x

f = checkify.checkify(f, errors=checkify.float_checks)
f = pjit(
  f,
  in_shardings=PartitionSpec('x', None),
  out_shardings=(None, PartitionSpec('x', None)))

with jax.sharding.Mesh(mesh.devices, mesh.axis_names):
 err, data = f(input_data)
err.throw()
# ValueError: divided by zero at <...>:4 (f)

grad#

如果您使用 checkify-of-grad,您的梯度计算也将被检测

def f(x):
 return x / (1 + jnp.sqrt(x))

grad_f = jax.grad(f)

err, _ = checkify.checkify(grad_f, errors=checkify.nan_checks)(0.)
print(err.get())
>> nan generated by primitive mul at <...>:3 (f)

请注意,f 中没有乘法,但是其梯度计算中有一个乘法(这就是生成 NaN 的地方!)。因此,请使用 checkify-of-grad 向前向和后向传递操作添加自动检查。

checkify.check 仅应用于您函数的原始值。如果您想在梯度值上使用 check,请使用 custom_vjp

@jax.custom_vjp
def assert_gradient_negative(x):
 return x

def fwd(x):
 return assert_gradient_negative(x), None

def bwd(_, grad):
 checkify.check(grad < 0, "gradient needs to be negative!")
 return (grad,)

assert_gradient_negative.defvjp(fwd, bwd)

jax.grad(assert_gradient_negative)(-1.)
# ValueError: gradient needs to be negative!

jax.experimental.checkify 的优点和局限性#

优点#

  • 您可以在任何地方使用它(错误是“只是值”,并且在诸如其他值之类的转换下以直观的方式运行)

  • 自动检测:您不需要对代码进行本地修改。相反,checkify 可以检测所有代码!

局限性#

  • 添加大量运行时检查可能会很昂贵(例如,向每个原语添加 NaN 检查会将大量操作添加到您的计算中)

  • 需要将错误值从函数中进行管道输出并手动抛出错误。如果错误未被明确抛出,您可能会错过错误!

  • 抛出错误值将在主机上实例化该错误值,这意味着这是一个阻塞操作,会破坏 JAX 的异步提前运行。