jax.stages
模块#
编译执行过程各阶段的接口。
JAX 转换(例如 jax.jit
和 jax.pmap
)支持即时编译执行,同时也支持一种常见的提前显式降级和编译的方法。此模块定义了代表此过程各个阶段的类型。
更多信息,请参阅 AOT 演练。
类#
- class jax.stages.Wrapped(*args, **kwargs)[源代码]#
一个准备好进行跟踪、降级和编译的函数。
此协议反映了诸如
jax.jit
等函数的输出。调用它会导致 JIT(即时)降级、编译和执行。它也可以在编译之前显式降级,并且可以在执行之前编译结果。
- class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False)[源代码]#
针对参数类型和值专门化的函数降级。
降级是准备好进行编译的计算。此类将降级与稍后编译和执行它所需的其余信息一起携带。它还为查询 JAX 各个降级路径(
jit()
、pmap()
等)上的降级计算属性提供了一个通用 API。- 参数:
lowering (XlaLowering)
args_info (Any)
out_tree (tree_util.PyTreeDef)
no_kwargs (bool)
- as_text(dialect=None, *, debug_info=False)[源代码]#
此降级的可读文本表示形式。
旨在用于可视化和调试目的。这不必是有效或可靠的序列化。如果需要可靠且可移植的序列化,请使用 jax.export。
- compile(compiler_options=None)[源代码]#
编译,返回相应的
Compiled
实例。- 参数:
compiler_options (CompilerOptions | None | None)
- 返回类型:
- compiler_ir(dialect=None)[源代码]#
此降级的任意对象表示形式。
旨在用于调试目的。这不是有效或可靠的序列化。输出不能保证在各个调用之间保持一致。如果需要可靠且可移植的序列化,请使用 jax.export。
如果不可用,则返回
None
,例如基于后端、编译器或运行时。- 参数:
dialect (str | None | None) – 可选字符串,指定降级方言(例如“stablehlo”或“hlo”)。
- 返回类型:
Any | None
- class jax.stages.Compiled(executable, args_info, out_tree, no_kwargs=False)[源代码]#
针对类型/值专门化的函数的已编译表示形式。
编译的计算与可执行文件以及执行它所需的其余信息相关联。它还为查询 JAX 各个编译路径和后端上的编译计算属性提供了一个通用 API。
- 参数:
args_info (Any)
out_tree (tree_util.PyTreeDef)
- as_text()[源代码]#
此可执行文件的可读文本表示形式。
旨在用于可视化和调试目的。这不是有效或可靠的序列化。
如果不可用,则返回
None
,例如基于后端、编译器或运行时。- 返回类型:
str | None
- cost_analysis()[源代码]#
执行成本估算的摘要。
旨在用于可视化和调试目的。此对象输出是一些简单的数据结构,可以轻松打印或序列化(例如,带有数值叶子的嵌套字典、列表和元组)。但是,其结构可以是任意的:它可能在 JAX 和 jaxlib 的不同版本之间,甚至在各个调用之间不一致。
如果不可用,则返回
None
,例如基于后端、编译器或运行时。- 返回类型:
Any | None