jax.lax.platform_dependent

jax.lax.platform_dependent#

jax.lax.platform_dependent(*args, default=None, **per_platform)[source]#

分阶段执行平台特定代码。

在 JAX 中,运行计算的实际平台是在很晚的时候确定的,例如,基于数据所在的 위치。使用 AOT 降低或序列化时,计算可能在不同的机器上编译和执行,甚至可能在降低时不可用的平台上执行。这意味着使用 Python 条件语句(例如,基于当前默认 JAX 平台)编写平台特定代码是不安全的。相反,可以使用 platform_dependent

用法

def cpu_code(*args): ...
def tpu_code(*args): ...
def other_platforms_code(*args): ...
res = platform_dependent(*args, cpu=cpu_code, tpu=tpu_code,
                         default=other_platforms_code)

当分阶段执行的代码在 CPU 上执行时,这相当于 cpu_code(*args),在 TPU 上执行相当于 tpu_code(*args),而在任何其他平台上则相当于 other_platforms_code(*args)。与 Python 条件不同,所有备选方案都将被追踪并分阶段执行到 Jaxpr。这类似于,并且在实现上依赖于,switch(),它继承了该函数在变换下的行为。

switch() 不同,执行代码的选择是在更早的阶段做出的:大多数情况下是在降低阶段,此时降低平台已知;在罕见的多平台降低和序列化情况下,StableHLO 代码将包含一个基于实际平台的条件。此条件会在编译平台已知时,在编译前实时解析。这意味着编译器实际上从未看到过条件语句。

参数:
  • *args (Any) – 传递给每个分支的 JAX 数组。可以是 PyTrees。

  • **per_platform (Callable[..., _T]) – 用于不同平台的分支。这些分支是使用 *args 调用的 JAX 可调用对象。关键字是平台名称,例如 ‘cpu’,‘tpu’,‘cuda’,‘rocm’。

  • default (Callable[..., _T] | None | None) – 可选的默认分支,用于未在 per_platform 中提到的平台。如果不存在 default,则当代码针对未在 per_platform 中提到的平台降低时,会发生错误。

返回值:

per_platform[execution_platform](*args)