jax.export
模块#
jax.export
是一个用于导出和序列化 JAX 函数以进行持久存档的库。
请参阅导出和序列化文档。
类#
- class jax.export.Exported(fun_name, in_tree, in_avals, out_tree, out_avals, in_shardings_hlo, out_shardings_hlo, nr_devices, platforms, ordered_effects, unordered_effects, disabled_safety_checks, mlir_module_serialized, calling_convention_version, module_kept_var_idx, uses_global_constants, _get_vjp)[源代码]#
一个降低为 StableHLO 的 JAX 函数。
- 参数:
fun_name (str)
in_tree (tree_util.PyTreeDef)
in_avals (tuple[core.ShapedArray, ...])
out_tree (tree_util.PyTreeDef)
out_avals (tuple[core.ShapedArray, ...])
in_shardings_hlo (tuple[HloSharding | None, ...])
out_shardings_hlo (tuple[HloSharding | None, ...])
nr_devices (int)
ordered_effects (tuple[effects.Effect, ...])
unordered_effects (tuple[effects.Effect, ...])
disabled_safety_checks (Sequence[DisabledSafetyCheck])
mlir_module_serialized (bytes)
calling_convention_version (int)
uses_global_constants (bool)
- in_tree#
一个 PyTreeDef,描述降低的 JAX 函数的元组 (args, kwargs)。实际的降低不依赖于 in_tree,但这可以用来使用相同的参数结构调用导出的函数。
- 类型:
tree_util.PyTreeDef
- out_tree#
一个 PyTreeDef,描述降低的 JAX 函数的结果。
- 类型:
tree_util.PyTreeDef
- in_shardings_hlo#
扁平化的输入分片,一个与 in_avals 一样长的序列。None 表示未指定分片。请注意,这些不包括网格或网格中使用的实际设备。有关如何将这些转换为可与 JAX API 一起使用的分片规范,请参阅 in_shardings_jax。
- 类型:
tuple[HloSharding | None, …]
- out_shardings_hlo#
扁平化的输出分片,一个与 out_avals 一样长的序列。None 表示未指定分片。请注意,这些不包括网格或网格中使用的实际设备。有关如何将这些转换为可与 JAX API 一起使用的分片规范,请参阅 out_shardings_jax。
- 类型:
tuple[HloSharding | None, …]
- platforms#
一个包含应导出函数的平台的元组。JAX 中的平台集是开放式的;用户可以添加平台。JAX 内置平台有:“tpu”、“cpu”、“cuda”、“rocm”。请参阅 https://jax.ac.cn/en/latest/export/export.html#cross-platform-and-multi-platform-export。
- ordered_effects#
序列化模块中存在的有序效应。这是从序列化版本 9 开始存在的。有关存在有序效应时的调用约定,请参阅 https://jax.ac.cn/en/latest/export/export.html#module-calling-convention。
- 类型:
tuple[effects.Effect, …]
- calling_convention_version#
导出模块的调用约定的版本号。有关更多版本控制详细信息,请参阅 https://jax.ac.cn/en/latest/export/export.html#calling-convention-versions。
- 类型:
- uses_global_constants#
mlir_module_serialized 是否使用形状多态性或多平台导出。这可能是因为 in_avals 包含维度变量,或者由于对具有维度变量或平台索引参数的导出模块的内部调用。此类模块在 XLA 编译之前需要进行形状细化。
- 类型:
- disabled_safety_checks#
导出时禁用的安全检查的描述符列表。请参阅 DisabledSafetyCheck 的文档字符串。
- 类型:
Sequence[DisabledSafetyCheck]
- _get_vjp#
一个可选函数,它接受当前导出的函数并返回导出的 VJP 函数。VJP 函数接受一个参数的扁平列表,从原始参数开始,然后是每个原始输出的余切参数。它返回一个元组,其中包含与扁平化原始输入相对应的余切。
请参阅 [mlir_module 的调用约定说明](https://jax.ac.cn/en/latest/export/export.html#module-calling-convention)。
- call(*args, **kwargs)[源代码]#
调用 JAX 程序中导出的函数。
- 参数:
args – 传递给导出函数的位置参数。这应该是一个数组的 pytree,其 pytree 结构与导出该函数所用的参数相同。
kwargs – 传递给导出函数的关键字参数。
- 返回:一个结果数组的 pytree,其结构与
导出函数的结果相同。
调用支持反向模式 AD,以及导出支持的所有功能:形状多态性、多平台、设备多态性。请参阅 [JAX 导出文档](https://jax.ac.cn/en/latest/export/export.html) 中的示例。
- in_shardings_jax(mesh)[源代码]#
创建与 self.in_shardings_hlo 对应的 Shardings。
Exported 对象将 in_shardings_hlo 存储为 HloShardings,它们独立于网格或设备集。此方法构造可用于 JAX API(例如 jax.jit 或 jax.device_put)中的 Sharding。
使用示例
>>> from jax import export >>> # Prepare the exported object: >>> exp_mesh = sharding.Mesh(jax.devices(), ("a",)) >>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x), ... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a"))) ... )(np.arange(jax.device_count())) >>> exp.in_shardings_hlo ({devices=[8]<=[8]},) >>> # Create a mesh for running the exported object >>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",)) >>> # Put the args and kwargs on the appropriate devices >>> run_arg = jax.device_put(np.arange(jax.device_count()), ... exp.in_shardings_jax(run_mesh)[0]) >>> res = exp.call(run_arg) >>> res.addressable_shards [Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]), Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]), Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]), Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]), Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]), Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]), Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]), Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
- 参数:
mesh (sharding.Mesh)
- 返回类型:
Sequence[sharding.Sharding | None]
- out_shardings_jax(mesh)[源代码]#
创建与 self.out_shardings_hlo 对应的 Shardings。
请参阅 in_shardings_jax 的文档。
- 参数:
mesh (sharding.Mesh)
- 返回类型:
Sequence[sharding.Sharding | None]
- class jax.export.DisabledSafetyCheck(_impl)[源代码]#
一个在(反)序列化时应跳过的安全检查。
大多数此类检查在序列化时执行,但有些会延迟到反序列化时执行。禁用的检查列表附加到序列化中,例如作为 jax.export.Exported 或 tf.XlaCallModuleOp 的字符串属性序列。
使用 jax2tf 时,可以通过传递 TF_XLA_FLAGS=–tf_xla_call_module_disabled_checks=platform 来禁用更多反序列化安全检查。
- 参数:
_impl (str)
- classmethod custom_call(target_name)[源代码]#
允许序列化已知不稳定的调用目标。
仅在序列化时有效。:param target_name: 要允许的自定义调用目标的名称。
- 参数:
target_name (str)
- 返回类型:
函数#
|
导出 JAX 函数以进行持久序列化。 |
|
反序列化一个 Exported 对象。 |
int([x]) -> 整数 int(x, base=10) -> 整数 |
|
int([x]) -> 整数 int(x, base=10) -> 整数 |
|
检索默认导出平台。 |
|
|
注册自定义 PyTree 节点以进行序列化和反序列化。 |
|
注册命名元组以进行序列化和反序列化。 |