jax.lax.platform_dependent#

jax.lax.platform_dependent(*args, default=None, **per_platform)[源代码]#

分阶段输出平台特定代码。

在 JAX 中,计算实际运行的平台是在很晚的时候才确定的,例如,基于数据所在的位置。当使用 AOT 降阶或序列化时,计算可能会在不同的机器上编译和执行,甚至在降阶时不可用的平台上编译和执行。这意味着使用 Python 条件语句编写平台相关的代码是不安全的,例如,基于当前默认的 JAX 平台。相反,可以使用 platform_dependent

用法

def cpu_code(*args): ...
def tpu_code(*args): ...
def other_platforms_code(*args): ...
res = platform_dependent(*args, cpu=cpu_code, tpu=tpu_code,
                         default=other_platforms_code)

当分阶段输出的代码在 CPU 上执行时,这等效于 cpu_code(*args),在 TPU 上等效于 tpu_code(*args),在任何其他平台上等效于 other_platforms_code(*args)。与 Python 条件语句不同,所有备选项都会被跟踪并分阶段输出到 Jaxpr。这类似于 switch(),并且是根据它实现的,它也继承了在转换下的行为。

switch() 不同,选择执行什么是在更早的时候进行的:在大多数情况下,当知道降阶平台时,在降阶期间进行;在多平台降阶和序列化的罕见情况下,StableHLO 代码将包含对实际平台的条件判断。此条件在编译平台已知时,在编译之前的恰当时机解析。这意味着编译器实际上永远不会看到条件。

参数:
  • *args (Any) – 传递给每个分支的 JAX 数组。可能是 PyTree。

  • **per_platform (Callable[..., _T]) – 用于不同平台的分支。这些分支是使用 *args 调用的 JAX 可调用对象。关键字是平台名称,例如,“cpu”、“tpu”、“cuda”、“rocm”。

  • default (Callable[..., _T] | None | None) – 可选的默认分支,用于 per_platform 中未提及的平台。如果不存在 default,则在为 per_platform 中未提及的平台降阶代码时会发生错误。

返回:

per_platform[execution_platform](*args)