jax.debug.callback#
- jax.debug.callback(callback, *args, ordered=False, **kwargs)[源代码]#
调用一个可分阶段的 Python 回调函数。
更多说明,请参阅外部回调。
jax.debug.callback
使您能够传入一个 Python 函数,该函数可以在分阶段的 JAX 程序中调用。jax.debug.callback
遵循现有的 JAX 转换纯操作语义,因此不知道副作用。这意味着在存在高阶原语和转换的情况下,该效果可能会被删除、复制或潜在地重新排序。我们希望有这种行为,因为我们希望
jax.debug.callback
是“无害的”,也就是说,我们希望这些原语尽可能少地改变 JAX 计算,同时尽可能多地揭示有关它们的信息,例如计算的哪些部分被复制或删除。- 参数:
callback (Callable[..., None]) – 一个返回 None 的 Python 可调用对象。
*args (Any) – 回调函数的位置参数。
ordered (bool) – 一个仅关键字参数,用于指示分阶段计算是否将强制执行此回调相对于其他有序回调的顺序。
**kwargs (Any) – 回调函数的关键字参数。
- 返回:
None
- 返回类型:
None
另请参阅
jax.experimental.io_callback()
:为不纯函数设计的回调函数。jax.pure_callback()
:为纯函数设计的回调函数。jax.debug.print()
:为打印设计的回调函数。