jax.experimental.pjit
模块#
API#
- jax.experimental.pjit.pjit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None, compiler_options=None)[源代码]#
使
fun
被编译并自动跨多个设备进行分区。注意:此函数现在等效于 jax.jit,请改用它。返回的函数具有与
fun
语义等效的语义,但会被编译成在多个设备(例如,多个 GPU 或多个 TPU 核心)上运行的 XLA 计算。如果fun
的 jitted 版本无法装入单个设备的内存,或者为了通过在多个设备上并行运行每个操作来加速fun
,这会非常有用。设备上的分区会根据在
in_shardings
中指定的输入分区和在out_shardings
中指定的输出分区的传播自动进行。这两个参数中指定的资源必须引用网格轴,如jax.sharding.Mesh()
上下文管理器所定义。请注意,在pjit()
应用时网格定义会被忽略,并且返回的函数将使用每个调用站点上可用的网格定义。如果输入未基于
in_shardings
正确分区,则输入到pjit()
的函数将自动跨设备分区。在某些情况下,确保输入已正确预分区可以提高性能。例如,如果将一个pjit()
函数的输出传递给另一个pjit()
函数(或循环中的相同pjit()
函数),请确保相关的out_shardings
与相应的in_shardings
匹配。注意
多进程平台:在诸如 TPU pod 之类的多进程平台上,
pjit()
可用于跨进程在所有可用设备上运行计算。为了实现这一点,pjit()
被设计用于 SPMD Python 程序,其中每个进程都运行相同的 Python 代码,以便所有进程以相同的顺序运行相同的pjit()
函数。在此配置中运行时,网格应包含所有进程中的设备。所有输入参数必须是全局形状的。
fun
仍然会在网格中的所有设备上执行,包括来自其他进程的设备,并且会获得作为单个数组分布在多个进程中的数据的全局视图。SPMD 模型还要求在所有进程上以相同的顺序运行相同的多进程
pjit()
函数,但它们可以与在单个进程中运行的任意操作穿插。- 参数:
fun (Callable) – 要编译的函数。应该是一个纯函数,因为副作用可能只执行一次。它的参数和返回值应该是数组、标量或(嵌套的)标准 Python 容器(元组/列表/字典)。由
static_argnums
指示的位置参数可以是任何内容,只要它们是可哈希的并且定义了相等操作即可。静态参数作为编译缓存键的一部分包含,这就是为什么必须定义哈希和相等运算符的原因。in_shardings –
与
fun
参数结构匹配的 Pytree,所有实际参数都被资源分配规范替换。指定 Pytree 前缀(例如,用一个值代替整个子树)也是有效的,在这种情况下,叶子会广播到该子树中的所有值。in_shardings
参数是可选的。JAX 将从输入jax.Array
推断分片,如果无法推断分片,则默认为复制输入。有效的资源分配规范是
Sharding
,它将决定如何对值进行分区。使用它时,不需要使用网格上下文管理器。None
是一种特殊情况,其语义为如果未提供网格上下文管理器,则 JAX 可以自由选择它想要的任何分片。对于 in_shardings,JAX 将其标记为已复制,但此行为在将来可能会更改。对于 out_shardings,我们将依赖 XLA GSPMD 分区器来确定输出分片。
如果提供了网格上下文管理器,则 None 将意味着该值将在网格的所有设备上复制。
为了向后兼容,in_shardings 仍然支持接收
PartitionSpec
。此选项仅可与网格上下文管理器一起使用。PartitionSpec
,长度最多等于分区值秩的元组。每个元素可以是None
、网格轴或网格轴的元组,并指定分配给对值中与其在规范中的位置匹配的维度进行分区的资源集。
每个维度的大小必须是分配给它的资源总数的倍数。
out_shardings – 与
in_shardings
类似,但指定函数输出的资源分配。out_shardings
参数是可选的。如果未指定,jax.jit()
将使用 GSPMD 的分片传播来确定如何对输出进行分片。static_argnums (int | Sequence[int] | None | None) –
可选的 int 或 int 集合,用于指定将哪些位置参数视为静态(编译时常量)。仅依赖于静态参数的操作将在 Python 中(在跟踪期间)进行常量折叠,因此相应的参数值可以是任何 Python 对象。
静态参数应该是可哈希的,这意味着实现了
__hash__
和__eq__
,并且是不可变的。使用这些常量的不同值调用 jitted 函数将触发重新编译。不是数组或其容器的参数必须标记为静态。如果未提供
static_argnums
,则不会将任何参数视为静态。static_argnames (str | Iterable[str] | None | None) – 可选的字符串或字符串集合,用于指定将哪些命名参数视为静态(编译时常量)。有关详细信息,请参阅
static_argnums
上的注释。如果未提供但设置了static_argnums
,则默认值基于调用inspect.signature(fun)
来查找相应的命名参数。donate_argnums ( int | Sequence[int] | None | None) –
指定哪些位置参数缓冲区“捐赠”给计算。如果计算完成后您不再需要这些参数缓冲区,则捐赠它们是安全的。在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如回收您的一个输入缓冲区来存储结果。您不应重复使用捐赠给计算的缓冲区,如果您尝试这样做,JAX 将会引发错误。默认情况下,不捐赠任何参数缓冲区。
如果既没有提供
donate_argnums
也没有提供donate_argnames
,则不捐赠任何参数。如果未提供donate_argnums
但提供了donate_argnames
,或者反之亦然,JAX 将使用inspect.signature(fun)
来查找与donate_argnames
(或反之亦然)对应的任何位置参数。如果同时提供了donate_argnums
和donate_argnames
,则不会使用inspect.signature
,并且只会捐赠donate_argnums
或donate_argnames
中列出的实际参数。有关缓冲区捐赠的更多详细信息,请参阅 FAQ。
donate_argnames ( str | Iterable[str] | None | None) – 可选的字符串或字符串集合,指定哪些命名参数捐赠给计算。有关详细信息,请参阅关于
donate_argnums
的注释。如果未提供但设置了donate_argnums
,则默认值基于调用inspect.signature(fun)
来查找相应的命名参数。keep_unused ( bool) – 如果为 False(默认值),JAX 确定 fun 未使用的参数可能会从生成的已编译 XLA 可执行文件中删除。此类参数将不会传输到设备或提供给底层可执行文件。如果为 True,则不会修剪未使用的参数。
device (xc.Device | None | None) – 此参数已弃用。请在将参数传递给 jit 之前将它们放置在您想要的设备上。可选,jit 函数将在其上运行的设备。(可通过
jax.devices()
获取可用设备。)默认值继承自 XLA 的 DeviceAssignment 逻辑,通常是使用jax.devices()[0]
。backend ( str | None | None) – 此参数已弃用。请在将参数传递给 jit 之前将它们放置在您想要的后端上。可选,表示 XLA 后端的字符串:
'cpu'
、'gpu'
或'tpu'
。inline ( bool)
abstracted_axes (Any | None | None)
- 返回值:
fun
的包装版本,设置为即时编译,并由每次调用时可用的网格自动分区。- 返回类型:
JitWrapped
例如,可以通过单个
pjit()
应用程序将卷积算子自动分区到任意一组设备上>>> import jax >>> import jax.numpy as jnp >>> import numpy as np >>> from jax.sharding import Mesh, PartitionSpec >>> from jax.experimental.pjit import pjit >>> >>> x = jnp.arange(8, dtype=jnp.float32) >>> f = pjit(lambda x: jax.numpy.convolve(x, jnp.asarray([0.5, 1.0, 0.5]), 'same'), ... in_shardings=None, out_shardings=PartitionSpec('devices')) >>> with Mesh(np.array(jax.devices()), ('devices',)): ... print(f(x)) [ 0.5 2. 4. 6. 8. 10. 12. 10. ]