jax.experimental.checkify.check#

jax.experimental.checkify.check(pred, msg, *fmt_args, debug=False, **fmt_kwargs)[源代码]#

检查一个断言,如果断言为 False,则添加一个带有 msg 的错误。

这是一个有副作用的操作,不能进行分段(jitted/scanned/…)。在对带有检查的函数进行分段之前,请使用 checkify() 函数进行处理!

参数:
  • pred (Bool) – 如果为 False,则添加 FailedCheckError 错误。

  • msg (str) – 如果添加错误,则为错误消息。可以是一个格式化字符串。

  • debug (bool) – 是否启用调试模式。如果为 True,则在执行期间将删除检查。如果为 False,则必须使用 checkify.checkify 将检查功能化。

  • fmt_argsmsg 的位置和关键字格式化参数,例如:check(.., "检查在值 {} {named_arg} 上失败", x, named_arg=y) 请注意,这些参数可以是跟踪值,允许您将运行时值添加到错误消息中。 请注意,跟踪这些运行时数组将增加您的内存使用量,即使没有发生错误。

  • fmt_kwargsmsg 的位置和关键字格式化参数,例如:check(.., "检查在值 {} {named_arg} 上失败", x, named_arg=y) 请注意,这些参数可以是跟踪值,允许您将运行时值添加到错误消息中。 请注意,跟踪这些运行时数组将增加您的内存使用量,即使没有发生错误。

返回类型:

None

例如

>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental import checkify
>>> def f(x):
...   checkify.check(x>0, "{x} needs to be positive!", x=x)
...   return 1/x
>>> checked_f = checkify.checkify(f)
>>> err, out = jax.jit(checked_f)(-3.)
>>> err.throw()  
Traceback (most recent call last):
  ...
jax._src.checkify.JaxRuntimeError: -3. needs to be positive!