jax.experimental.checkify.checkify#
- jax.experimental.checkify.checkify(f, errors=frozenset({<class 'jax._src.checkify.FailedCheckError'>}))[source]#
将 fun 中的 check 调用函数化,并可选地添加运行时错误检查。
运行时错误要么是用户添加的
check()
断言,要么是根据errors
参数自动添加的检查,例如 NaN 检查。返回的函数将返回一个错误对象 err 以及原始函数的输出。
err.get()
将返回None
(如果未发生错误)或包含错误消息的字符串。此错误消息将对应于发生的第一个错误。err.throw()
将引发一个 ValueError,其中包含错误消息(如果发生错误)。默认情况下,仅启用用户添加的
check()
断言。您可以通过errors
参数启用自动检查。- 自动检查可以启用哪些检查,以及何时生成错误。
user_checks
:评估结果为 False 的check()
。nan_checks
:浮点运算生成了 NaN 值作为输出。div_checks
:除以零。index_checks
:索引超出范围。
可以通过传递错误的 Set(例如
errors=nan_checks
)来同时启用多个类别。多个集合可以重新组合(例如errors=float_checks|user_checks
)。- 参数:
- 返回值:
一个函数,它接受与
fun
相同的参数,并作为输出返回一个对,其中第一个元素是Error
值,表示第一个失败的check()
,第二个元素是fun
的原始输出。- 返回类型:
例如
>>> import jax >>> import jax.numpy as jnp >>> from jax.experimental import checkify >>> >>> @jax.jit ... def f(x): ... y = jnp.sin(x) ... return x+y >>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf) >>> err.throw() Traceback (most recent call last): ... jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin