jax.export 模块#

jax.export 是一个用于导出和序列化 JAX 函数以进行持久存档的库。

请参阅导出和序列化文档。

#

class jax.export.Exported(fun_name, in_tree, in_avals, out_tree, out_avals, in_shardings_hlo, out_shardings_hlo, nr_devices, platforms, ordered_effects, unordered_effects, disabled_safety_checks, mlir_module_serialized, calling_convention_version, module_kept_var_idx, uses_global_constants, _get_vjp)[源代码]#

一个降级为 StableHLO 的 JAX 函数。

参数:
  • fun_name (str)

  • in_tree (tree_util.PyTreeDef)

  • in_avals (tuple[core.ShapedArray, ...])

  • out_tree (tree_util.PyTreeDef)

  • out_avals (tuple[core.ShapedArray, ...])

  • in_shardings_hlo (tuple[HloSharding | None, ...])

  • out_shardings_hlo (tuple[HloSharding | None, ...])

  • nr_devices (int)

  • platforms (tuple[str, ...])

  • ordered_effects (tuple[effects.Effect, ...])

  • unordered_effects (tuple[effects.Effect, ...])

  • disabled_safety_checks (Sequence[DisabledSafetyCheck])

  • mlir_module_serialized (bytes)

  • calling_convention_version (int)

  • module_kept_var_idx (tuple[int, ...])

  • uses_global_constants (bool)

  • _get_vjp (Callable[[Exported], Exported] | None)

fun_name#

导出的函数的名称,用于错误消息。

类型:

str

in_tree#

一个 PyTreeDef,描述了降级的 JAX 函数的元组 (args, kwargs)。实际的降级不依赖于 in_tree,但可以使用它来使用相同的参数结构调用导出的函数。

类型:

tree_util.PyTreeDef

in_avals#

输入抽象值的扁平元组。形状中可能包含维度表达式。

类型:

tuple[core.ShapedArray, …]

out_tree#

一个 PyTreeDef,描述了降级的 JAX 函数的结果。

类型:

tree_util.PyTreeDef

out_avals#

输出抽象值的扁平元组。形状中可能包含维度表达式,其维度变量与 in_avals 中的维度变量相同。

类型:

tuple[core.ShapedArray, …]

in_shardings_hlo#

扁平化的输入分片,一个长度与 in_avals 相同的序列。None 表示未指定分片。请注意,这些不包括网格或网格中使用的实际设备。有关如何将这些转换为可用于 JAX API 的分片规范,请参阅 in_shardings_jax

类型:

tuple[HloSharding | None, …]

out_shardings_hlo#

扁平化的输出分片,一个长度与 out_avals 相同的序列。None 表示未指定分片。请注意,这些不包括网格或网格中使用的实际设备。有关如何将这些转换为可用于 JAX API 的分片规范,请参阅 out_shardings_jax

类型:

tuple[HloSharding | None, …]

nr_devices#

该模块已降级到的设备数量。

类型:

int

platforms#

一个元组,其中包含应导出该函数的平台。JAX 中的平台集是开放的;用户可以添加平台。JAX 内置平台为:“tpu”、“cpu”、“cuda”、“rocm”。请参阅 https://jax.ac.cn/en/latest/export/export.html#cross-platform-and-multi-platform-export

类型:

tuple[str, …]

ordered_effects#

序列化模块中存在的有序效果。从序列化版本 9 开始存在。有关存在有序效果时的调用约定,请参阅 https://jax.ac.cn/en/latest/export/export.html#module-calling-convention

类型:

tuple[effects.Effect, …]

unordered_effects#

序列化模块中存在的无序效果。从序列化版本 9 开始存在。

类型:

tuple[effects.Effect, …]

mlir_module_serialized#

序列化的降级 VHLO 模块。

类型:

bytes

calling_convention_version#

导出模块的调用约定的版本号。有关更多版本控制详细信息,请参阅 https://jax.ac.cn/en/latest/export/export.html#calling-convention-versions

类型:

int

module_kept_var_idx#

必须传递给模块的 in_avals 中参数的已排序索引。其他参数已被删除,因为它们未使用。

类型:

tuple[int, …]

uses_global_constants#

mlir_module_serialized 是否使用形状多态性或多平台导出。这可能是因为 in_avals 包含维度变量,或者由于具有维度变量或平台索引参数的导出模块的内部调用。此类模块在 XLA 编译之前需要形状细化。

类型:

bool

disabled_safety_checks#

一个描述导出时已禁用安全检查的描述符列表。请参阅 DisabledSafetyCheck 的文档字符串。

类型:

Sequence[DisabledSafetyCheck]

_get_vjp#

一个可选函数,它接收当前导出的函数并返回导出的 VJP 函数。VJP 函数接收一个扁平的参数列表,首先是原始参数,然后是每个原始输出的余切参数。它返回一个元组,其中包含与扁平化的原始输入相对应的余切。

类型:

Callable[[Exported], Exported] | None

请参阅 [关于 mlir_module 的调用约定的说明](https://jax.ac.cn/en/latest/export/export.html#module-calling-convention)。

call(*args, **kwargs)[源代码]#

从 JAX 程序调用导出的函数。

参数:
  • args – 要传递给导出函数的位置参数。这应该是一个数组的 pytree,其 pytree 结构与导出该函数的参数的结构相同。

  • kwargs – 要传递给导出函数的关键字参数。

返回:一个结果数组的 pytree,其结构与

导出函数的结果相同。

调用支持反向模式 AD,以及导出支持的所有功能:形状多态性、多平台、设备多态性。请参阅 [JAX 导出文档](https://jax.ac.cn/en/latest/export/export.html)中的示例。

has_vjp()[源代码]#

返回此导出的函数是否支持 VJP。

返回类型:

bool

in_shardings_jax(mesh)[源代码]#

创建与 self.in_shardings_hlo 对应的 Shardings。

导出的对象将 in_shardings_hlo 存储为 HloShardings,它们独立于网格或设备集。此方法构建可在 JAX API(如 jax.jitjax.device_put)中使用的 Sharding。

用法示例

>>> from jax import export
>>> # Prepare the exported object:
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
...                             in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
...     )(np.arange(jax.device_count()))
>>> exp.in_shardings_hlo
({devices=[8]<=[8]},)
>>> # Create a mesh for running the exported object
>>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",))
>>> # Put the args and kwargs on the appropriate devices
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
...     exp.in_shardings_jax(run_mesh)[0])
>>> res = exp.call(run_arg)
>>> res.addressable_shards
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
 Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
 Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]),
 Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]),
 Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]),
 Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]),
 Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]),
 Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
参数:

mesh (sharding.Mesh)

返回类型:

Sequence[sharding.Sharding | None]

mlir_module()[源代码]#

mlir_module_serialized 的字符串表示形式。

返回类型:

str

out_shardings_jax(mesh)[源代码]#

创建与 self.out_shardings_hlo 对应的 Sharding。

请参阅 in_shardings_jax 的文档。

参数:

mesh (sharding.Mesh)

返回类型:

Sequence[sharding.Sharding | None]

serialize(vjp_order=0)[源代码]#

序列化一个 Exported 对象。

参数:

vjp_order (int) – 要包含的最大 vjp 阶数。例如,值 2 表示我们序列化原始函数和 vjp 函数的两个阶数。这应该允许反序列化函数的二阶反向模式微分。即,jax.grad(jax.grad(f)).

返回类型:

bytearray

vjp()[源代码]#

获取导出的 VJP。

如果不可用,则返回 None,如果 Exported 对象是从没有 VJP 的外部格式加载的,则可能会发生这种情况。

返回类型:

Exported

class jax.export.DisabledSafetyCheck(_impl)[源代码]#

一个在(反)序列化时应跳过的安全检查。

大多数此类检查在序列化时执行,但有些检查会延迟到反序列化时执行。禁用检查的列表附加到序列化,例如,作为 jax.export.Exportedtf.XlaCallModuleOp 的字符串属性序列。

当使用 jax2tf 时,可以通过传递 TF_XLA_FLAGS=–tf_xla_call_module_disabled_checks=platform 来禁用更多反序列化安全检查。

参数:

_impl (str)

classmethod custom_call(target_name)[源代码]#

允许序列化已知不稳定的调用目标。

仅对序列化有效。 :param target_name: 要允许的自定义调用目标的名称。

参数:

target_name (str)

返回类型:

DisabledSafetyCheck

is_custom_call()[源代码]#

返回此指令允许的自定义调用目标。

返回类型:

str | None

classmethod platform()[源代码]#

允许编译平台与导出平台不同。

仅对反序列化有效。

返回类型:

DisabledSafetyCheck

函数#

export(fun_jit, *[, platforms, disabled_checks])

导出 JAX 函数以进行持久序列化。

deserialize(blob)

反序列化一个 Exported 对象。

minimum_supported_calling_convention_version

int([x]) -> integer int(x, base=10) -> integer

maximum_supported_calling_convention_version

int([x]) -> integer int(x, base=10) -> integer

default_export_platform()

检索默认导出平台。

register_pytree_node_serialization(nodetype, ...)

注册用于序列化和反序列化的自定义 PyTree 节点。

register_namedtuple_serialization(nodetype, ...)

注册用于序列化和反序列化的 namedtuple。

常量#

jax.export.minimum_supported_serialization_version#

支持的最小序列化版本;请参阅调用约定版本

jax.export.maximum_supported_serialization_version#

支持的最大序列化版本;请参阅调用约定版本