jax.stages 模块#

编译执行过程各个阶段的接口。

JAX 转换(如 jax.jitjax.pmap)在执行时进行即时编译,也支持一种提前显式降级和编译的通用方法。此模块定义了代表此过程各个阶段的类型。

有关更多信息,请参阅 AOT 演练

#

class jax.stages.Wrapped(*args, **kwargs)[源代码]#

一个准备好被跟踪、降级和编译的函数。

此协议反映了诸如 jax.jit 之类的函数的输出。调用它会导致 JIT(即时)降级、编译和执行。它也可以在编译之前显式降级,并在执行之前编译结果。

__call__(*args, **kwargs)[源代码]#

执行包装的函数,根据需要降级和编译。

lower(*args, **kwargs)[源代码]#

显式地为给定的参数降级此函数。

降级的函数会从 Python 中分阶段移出,并转换为编译器的输入语言,可能以依赖于后端的方式进行。它已准备好进行编译,但尚未编译。

返回:

一个 Lowered 实例,表示降级。

返回类型:

Lowered

trace(*args, **kwargs)[源代码]#

显式地为给定的参数跟踪此函数。

跟踪的函数会从 Python 中分阶段移出,并转换为 jaxpr。它已准备好进行降级,但尚未降级。

返回:

一个 Traced 实例,表示跟踪。

返回类型:

Traced

class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False)[源代码]#

专用于参数类型和值的函数的降级。

降级是准备好进行编译的计算。此类带有降级,以及稍后编译和执行它所需的其余信息。它还提供了一个通用 API,用于查询 JAX 各个降级路径(jit()pmap() 等)中已降级计算的属性。

参数:
  • lowering (XlaLowering)

  • args_info (Any)

  • out_tree (tree_util.PyTreeDef)

  • no_kwargs (bool)

as_text(dialect=None)[源代码]#

此降级的人类可读文本表示。

旨在用于可视化和调试。这不需要是有效或可靠的序列化。它直接转发给外部调用者。

参数:

dialect (str | None | None) – 可选字符串,指定降级方言(例如“stablehlo”)

返回类型:

str

compile(compiler_options=None)[源代码]#

编译,返回相应的 Compiled 实例。

参数:

compiler_options (CompilerOptions | None | None)

返回类型:

Compiled

compiler_ir(dialect=None)[源代码]#

此降级的任意对象表示。

旨在用于调试。这不是有效或可靠的序列化。输出不能保证在调用之间保持一致。

如果不可用,则返回 None,例如基于后端、编译器或运行时。

参数:

dialect (str | None | None) – 可选字符串,指定降级方言(例如“stablehlo”)

返回类型:

Any | None

cost_analysis()[源代码]#

执行成本估计的摘要。

旨在用于可视化和调试。此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如,带有数值叶子的嵌套字典、列表和元组)。但是,其结构可以是任意的:在 JAX 和 jaxlib 的版本之间,甚至在调用之间,其结构可能不一致。

如果不可用,则返回 None,例如基于后端、编译器或运行时。

返回类型:

Any | None

property in_tree: tree_util.PyTreeDef[源代码]#

对(位置参数,关键字参数)的对的树结构。

class jax.stages.Compiled(executable, args_info, out_tree, no_kwargs=False)[源代码]#

专用于类型/值的函数的编译表示。

编译后的计算与可执行文件相关联,以及执行它所需的其余信息。它还提供了一个通用 API,用于查询 JAX 各个编译路径和后端中编译后的计算的属性。

参数:
  • args_info (Any)

  • out_tree (tree_util.PyTreeDef)

__call__(*args, **kwargs)[源代码]#

将自身作为函数调用。

as_text()[源代码]#

此可执行文件的可读文本表示形式。

旨在用于可视化和调试目的。这不是有效的或可靠的序列化。

如果不可用,则返回 None,例如基于后端、编译器或运行时。

返回类型:

str | None

cost_analysis()[源代码]#

执行成本估计的摘要。

旨在用于可视化和调试。此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如,带有数值叶子的嵌套字典、列表和元组)。但是,其结构可以是任意的:在 JAX 和 jaxlib 的版本之间,甚至在调用之间,其结构可能不一致。

如果不可用,则返回 None,例如基于后端、编译器或运行时。

返回类型:

Any | None

property in_tree: tree_util.PyTreeDef[源代码]#

对(位置参数,关键字参数)的对的树结构。

memory_analysis()[源代码]#

内存需求估计的摘要。

旨在用于可视化和调试。此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如,带有数值叶子的嵌套字典、列表和元组)。但是,其结构可以是任意的:在 JAX 和 jaxlib 的版本之间,甚至在调用之间,其结构可能不一致。

如果不可用,则返回 None,例如基于后端、编译器或运行时。

返回类型:

Any | None

runtime_executable()[源代码]#

此可执行文件的任意对象表示形式。

旨在用于调试目的。这不是有效的或可靠的序列化。输出无法保证在不同调用之间保持一致性。

如果不可用,则返回 None,例如基于后端、编译器或运行时。

返回类型:

Any | None