jax.lax.platform_dependent#
- jax.lax.platform_dependent(*args, default=None, **per_platform)[源代码]#
分阶段输出特定于平台的代码。
在 JAX 中,实际运行计算的平台是在很晚的时候才确定的,例如,基于数据所在的位置。当使用 AOT 降低或序列化时,计算可能会在不同的机器上编译和执行,甚至在降低时不可用的平台上执行。这意味着使用 Python 条件语句编写平台相关代码是不安全的,例如,基于当前的默认 JAX 平台。相反,可以使用
platform_dependent
用法
def cpu_code(*args): ... def tpu_code(*args): ... def other_platforms_code(*args): ... res = platform_dependent(*args, cpu=cpu_code, tpu=tpu_code, default=other_platforms_code)
当分阶段输出的代码在 CPU 上执行时,等效于
cpu_code(*args)
;在 TPU 上执行时,等效于tpu_code(*args)
;在任何其他平台上执行时,等效于other_platforms_code(*args)
。与 Python 条件语句不同,所有备选方案都会被追踪并分阶段输出到 Jaxpr。这类似于switch()
,并且是在其基础上实现的,因此继承了其在转换下的行为。与
switch()
不同的是,执行哪个分支的选择会更早地做出:在大多数情况下,当已知降低平台时,会在降低期间做出选择;在极少数的多平台降低和序列化的情况下,StableHLO 代码将包含实际平台的条件语句。此条件语句会在编译之前、已知编译平台时及时解析。这意味着编译器实际上从未看到条件语句。- 参数:
*args (Any) – 传递给每个分支的 JAX 数组。可以是 PyTree。
**per_platform (Callable[..., _T]) – 用于不同平台的分支。这些分支是使用
*args
调用的 JAX 可调用对象。关键字是平台名称,例如“cpu”、“tpu”、“cuda”、“rocm”。default (Callable[..., _T] | None | None) – 可选的默认分支,用于未在
per_platform
中提及的平台。如果不存在default
,则当为未在per_platform
中提及的平台降低代码时,将会发生错误。
- 返回值:
值
per_platform[execution_platform](*args)
。