jax.jit#

jax.jit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None, compiler_options=None)[源代码]#

设置 fun 以使用 XLA 进行即时编译。

参数:
  • fun (Callable) –

    要进行 jit 编译的函数。fun 应该是一个纯函数。

    fun 的参数和返回值应该是数组、标量或它们的(嵌套)标准 Python 容器(元组/列表/字典)。由 static_argnums 指示的位置参数可以是任何可哈希类型。静态参数包含在编译缓存键中,这就是为什么必须定义哈希和相等运算符的原因。JAX 保留对 fun 的弱引用,以用作编译缓存键,因此对象 fun 必须是弱可引用的。

  • in_shardings – 可选,一个 Sharding 或带有 Sharding 叶子的 pytree,其结构是传递给 fun 的位置参数元组的树前缀。如果提供,则传递给 fun 的位置参数必须具有与 in_shardings 兼容的分片,否则会引发错误,并且编译后的计算具有与 in_shardings 对应的输入分片。如果未提供,则编译后的计算的输入分片将从参数分片中推断出来。

  • out_shardings – 可选,一个 Sharding 或带有 Sharding 叶子的 pytree,其结构是 fun 的输出的树前缀。如果提供,它的效果与对 ``fun`() 的输出应用相应的 jax.lax.with_sharding_constraint`s 相同。

  • static_argnums (int | Sequence[int] | None | None) –

    可选,一个 int 或一组 int,用于指定将哪些位置参数视为静态(跟踪和编译时常量)。

    静态参数应该是可哈希的,这意味着实现了 __hash____eq__,并且是不可变的。否则,它们可以是任意 Python 对象。使用不同的常量值调用 jit 编译的函数将触发重新编译。非类数组或其容器的参数必须标记为静态。

    如果既未提供 static_argnums 也未提供 static_argnames,则不会将任何参数视为静态。如果未提供 static_argnums 但提供了 static_argnames,反之亦然,则 JAX 使用 inspect.signature(fun) 来查找与 static_argnames 相对应的任何位置参数(反之亦然)。如果同时提供了 static_argnumsstatic_argnames,则不使用 inspect.signature,并且只有 static_argnumsstatic_argnames 中列出的实际参数才会被视为静态。

  • static_argnames (str | Iterable[str] | None | None) – 可选,一个字符串或一组字符串,用于指定将哪些命名参数视为静态(编译时常量)。有关详细信息,请参阅关于 static_argnums 的注释。如果未提供,但设置了 static_argnums,则默认值基于调用 inspect.signature(fun) 来查找相应的命名参数。

  • donate_argnums (int | Sequence[int] | None | None) –

    可选,一组整数,用于指定哪些位置参数缓冲区可以被计算覆盖并在调用者中标记为已删除。如果计算开始后您不再需要参数缓冲区,则捐赠参数缓冲区是安全的。在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如回收您的一个输入缓冲区来存储结果。您不应重用捐赠给计算的缓冲区;如果您尝试这样做,JAX 将引发错误。默认情况下,不捐赠任何参数缓冲区。

    如果既未提供 donate_argnums 也未提供 donate_argnames,则不会捐赠任何参数。如果未提供 donate_argnums 但提供了 donate_argnames,反之亦然,则 JAX 使用 inspect.signature(fun) 来查找与 donate_argnames 相对应的任何位置参数(反之亦然)。如果同时提供了 donate_argnumsdonate_argnames,则不使用 inspect.signature,并且只有 donate_argnumsdonate_argnames 中列出的实际参数才会被捐赠。

    有关缓冲区捐赠的更多详细信息,请参阅 FAQ

  • donate_argnames (str | Iterable[str] | None | None) – 可选,一个字符串或一组字符串,用于指定将哪些命名参数捐赠给计算。有关详细信息,请参阅关于 donate_argnums 的注释。如果未提供,但设置了 donate_argnums,则默认值基于调用 inspect.signature(fun) 来查找相应的命名参数。

  • keep_unused (bool) – 可选布尔值。如果为 False (默认值),则 JAX 确定 fun 未使用的参数可能会从生成的编译 XLA 可执行文件中删除。此类参数将不会传输到设备,也不会提供给底层可执行文件。如果为 True,则不会修剪未使用的参数。

  • device (xc.Device | None | None) – 这是一个实验性功能,API 很可能会更改。可选,jit 编译的函数将在其上运行的设备。(可以通过 jax.devices() 检索可用的设备。)默认值继承自 XLA 的 DeviceAssignment 逻辑,通常是使用 jax.devices()[0]

  • backend (str | None | None) – 这是一个实验性功能,API 很可能会更改。可选,表示 XLA 后端的字符串:'cpu''gpu''tpu'

  • inline (bool) – 可选布尔值。指定是否应将此函数内联到封闭的 jaxpr 中。默认为 False。

  • abstracted_axes (Any | None | None)

  • compiler_options (dict[str, Any] | None | None)

返回:

一个为即时编译设置的 fun 的包装版本。

返回类型:

pjit.JitWrapped

示例

在以下示例中,selu 可以被 XLA 编译成一个单独的融合内核

>>> import jax
>>>
>>> @jax.jit
... def selu(x, alpha=1.67, lmbda=1.05):
...   return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
>>>
>>> key = jax.random.key(0)
>>> x = jax.random.normal(key, (10,))
>>> print(selu(x))  
[-0.54485  0.27744 -0.29255 -0.91421 -0.62452 -0.24748
-0.85743 -0.78232  0.76827  0.59566 ]

当装饰函数时要传递诸如 static_argnames 等参数,一个常见的模式是使用 functools.partial()

>>> from functools import partial
>>>
>>> @partial(jax.jit, static_argnames=['n'])
... def g(x, n):
...   for i in range(n):
...     x = x ** 2
...   return x
>>>
>>> g(jnp.arange(4), 3)
Array([   0,    1,  256, 6561], dtype=int32)