jax.experimental.pjit 模块#

API#

jax.experimental.pjit.pjit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None, compiler_options=None)[源代码]#

使 fun 被编译并自动在多个设备上进行分区。

注意:此函数现在等效于 jax.jit,请改用该函数。返回的函数具有与 fun 等效的语义,但会被编译成在多个设备(例如,多个 GPU 或多个 TPU 核心)上运行的 XLA 计算。如果 fun 的 jit 版本无法容纳在单个设备的内存中,或者为了通过在多个设备上并行运行每个操作来加速 fun,则此方法会很有用。

设备上的分区是根据 in_shardings 中指定的输入分区和 out_shardings 中指定的输出分区的传播自动进行的。这两个参数中指定的资源必须引用网格轴,如 jax.sharding.Mesh() 上下文管理器所定义。 请注意,在 pjit() 应用时网格定义会被忽略,返回的函数将使用每个调用站点可用的网格定义。

如果 pjit() 函数的输入没有根据 in_shardings 正确分区,则会自动在设备上进行分区。在某些情况下,确保输入已经正确地预分区可以提高性能。例如,如果将一个 pjit() 函数的输出传递给另一个 pjit() 函数(或循环中的同一个 pjit() 函数),请确保相关的 out_shardings 与相应的 in_shardings 匹配。

注意

多进程平台: 在 TPU pod 等多进程平台上,pjit() 可用于跨进程在所有可用设备上运行计算。 为了实现这一点,pjit() 被设计用于 SPMD Python 程序,其中每个进程都运行相同的 Python 代码,以便所有进程按相同顺序运行相同的 pjit() 函数。

在此配置中运行时,网格应包含所有进程中的设备。所有输入参数都必须是全局形状。fun 仍然会在网格中的所有设备(包括来自其他进程的设备)上执行,并将获得跨多个进程的数据的全局视图,将其视为单个数组。

SPMD 模型还要求所有进程必须按相同顺序运行相同的多进程 pjit() 函数,但它们可以与在单个进程中运行的任意操作交错。

参数:
  • fun (Callable) – 要编译的函数。 应该是纯函数,因为副作用可能只执行一次。 它的参数和返回值应该是数组、标量或它们的(嵌套)标准 Python 容器(元组/列表/字典)。 由 static_argnums 指示的位置参数可以是任何内容,前提是它们是可哈希的并且具有已定义的相等操作。 静态参数包含在编译缓存键中,这就是为什么必须定义哈希和相等运算符的原因。

  • in_shardings

    fun 的参数结构匹配的 Pytree,所有实际参数都被资源分配规范替换。 指定 Pytree 前缀(例如,用一个值代替整个子树)也是有效的,在这种情况下,叶子会广播到该子树中的所有值。

    in_shardings 参数是可选的。 JAX 将从输入的 jax.Array 中推断分片,如果无法推断分片,则默认复制输入。

    有效的资源分配规范是

    • Sharding,它将决定如何对值进行分区。 使用此方法,不需要使用网格上下文管理器。

    • None 是一种特殊情况,其语义如下
      • 如果提供网格上下文管理器,则 JAX 可以自由选择它想要的任何分片。 对于 in_shardings,JAX 会将其标记为已复制,但此行为将来可能会更改。 对于 out_shardings,我们将依赖 XLA GSPMD 分区器来确定输出分片。

      • 如果提供了网格上下文管理器,则 None 将表示该值将在网格的所有设备上复制。

    • 为了向后兼容,in_shardings 仍然支持摄取 PartitionSpec。此选项只能与网格上下文管理器一起使用。

      • PartitionSpec,一个长度最多等于分区值的秩的元组。每个元素可以是 None、网格轴或网格轴的元组,并指定分配给分区值维度(与规范中的位置匹配)的资源集。

    每个维度的大小都必须是分配给它的资源总数的倍数。

  • out_shardings – 与 in_shardings 类似,但指定函数输出的资源分配。out_shardings 参数是可选的。如果未指定,jax.jit() 将使用 GSPMD 的分片传播来确定如何对输出进行分片。

  • static_argnums (int | Sequence[int] | None | None) –

    一个可选的整数或整数集合,用于指定将哪些位置参数视为静态(编译时常量)。 仅依赖于静态参数的操作将在 Python 中(在跟踪期间)进行常量折叠,因此相应的参数值可以是任何 Python 对象。

    静态参数应该是可哈希的,这意味着实现了 __hash____eq__,并且是不可变的。 使用这些常量的不同值调用 jit 函数将触发重新编译。 不是数组或其容器的参数必须标记为静态。

    如果未提供 static_argnums,则不会将任何参数视为静态。

  • static_argnames ( str | Iterable[str] | None | None) – 一个可选的字符串或字符串集合,用于指定哪些命名参数应被视为静态的(编译时常量)。有关详细信息,请参阅 static_argnums 的注释。如果未提供,但设置了 static_argnums,则默认值基于调用 inspect.signature(fun) 来查找相应的命名参数。

  • donate_argnums ( int | Sequence[int] | None | None) –

    指定哪些位置参数缓冲区“捐赠”给计算。如果在计算完成后不再需要参数缓冲区,则可以安全地捐赠它们。在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如回收您的一个输入缓冲区来存储结果。您不应该重用捐赠给计算的缓冲区,如果您尝试这样做,JAX 将会引发错误。默认情况下,不捐赠任何参数缓冲区。

    如果既未提供 donate_argnums 也未提供 donate_argnames,则不捐赠任何参数。如果未提供 donate_argnums 但提供了 donate_argnames,或者反之,JAX 使用 inspect.signature(fun) 来查找与 donate_argnames 对应的任何位置参数(或反之)。如果同时提供了 donate_argnumsdonate_argnames,则不会使用 inspect.signature,并且只会捐赠在 donate_argnumsdonate_argnames 中列出的实际参数。

    有关缓冲区捐赠的更多详细信息,请参阅 FAQ

  • donate_argnames ( str | Iterable[str] | None | None) – 一个可选的字符串或字符串集合,用于指定哪些命名参数被捐赠给计算。有关详细信息,请参阅 donate_argnums 的注释。如果未提供,但设置了 donate_argnums,则默认值基于调用 inspect.signature(fun) 来查找相应的命名参数。

  • keep_unused ( bool) – 如果为 False (默认值),则 JAX 确定为 fun 未使用的参数可能会从生成的已编译 XLA 可执行文件中删除。这些参数将不会传输到设备,也不会提供给底层可执行文件。如果为 True,则不会修剪未使用的参数。

  • device (xc.Device | None | None) – 此参数已弃用。请在将参数传递给 jit 之前,将参数放在您想要的设备上。可选,jit 函数将在其上运行的设备。(可以通过 jax.devices() 获取可用设备。)默认值继承自 XLA 的 DeviceAssignment 逻辑,通常是使用 jax.devices()[0]

  • backend ( str | None | None) – 此参数已弃用。请在将参数传递给 jit 之前,将参数放在您想要的后端上。可选,一个表示 XLA 后端的字符串:'cpu', 'gpu', 或 'tpu'

  • inline ( bool)

  • abstracted_axes (Any | None | None)

  • compiler_options ( dict[str, Any] | None | None)

返回:

一个经过包装的 fun 版本,设置为即时编译,并根据每次调用时可用的网格自动分区。

返回类型:

JitWrapped

例如,可以通过单个 pjit() 应用程序在任意设备集上自动分区卷积运算符

>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.sharding import Mesh, PartitionSpec
>>> from jax.experimental.pjit import pjit
>>>
>>> x = jnp.arange(8, dtype=jnp.float32)
>>> f = pjit(lambda x: jax.numpy.convolve(x, jnp.asarray([0.5, 1.0, 0.5]), 'same'),
...         in_shardings=None, out_shardings=PartitionSpec('devices'))
>>> with Mesh(np.array(jax.devices()), ('devices',)):
...   print(f(x))  
[ 0.5  2.   4.   6.   8.  10.  12.  10. ]