jax.jit#
- jax.jit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None, compiler_options=None)[source]#
设置
fun
以使用 XLA 进行即时编译。- 参数:
fun (Callable) –
要进行 JIT 编译的函数。
fun
应该是一个纯函数。fun
的参数和返回值应该是数组、标量或它们的(嵌套)标准 Python 容器(元组/列表/字典)。由static_argnums
指示的位置参数可以是任何可哈希类型。静态参数会包含在编译缓存键中,这就是为什么必须定义哈希和相等运算符的原因。JAX 保留对fun
的弱引用,用作编译缓存键,因此对象fun
必须是可弱引用的。in_shardings – 可选,一个
Sharding
或一个具有Sharding
叶子且结构是fun
的位置参数元组的树前缀的 PyTree。如果提供,则传递给fun
的位置参数必须具有与in_shardings
兼容的分片,否则会引发错误,并且编译后的计算具有与in_shardings
对应的输入分片。如果未提供,则编译后的计算的输入分片将从参数分片中推断。out_shardings – 可选,一个
Sharding
或一个具有Sharding
叶子且结构是fun
输出的树前缀的 PyTree。如果提供,其效果与将相应的jax.lax.with_sharding_constraint`s 应用于 ``fun`()
的输出相同。static_argnums (int | Sequence[int] | None | None) –
可选,一个整数或整数集合,用于指定将哪些位置参数视为静态的(跟踪时和编译时常量)。
静态参数应该是可哈希的,这意味着
__hash__
和__eq__
都已实现,并且是不可变的。否则,它们可以是任意 Python 对象。使用这些常量的不同值调用 JIT 编译的函数将触发重新编译。不是类数组或其容器的参数必须标记为静态。如果未提供
static_argnums
也未提供static_argnames
,则不会将任何参数视为静态。如果未提供static_argnums
但提供了static_argnames
,或者反之,JAX 会使用inspect.signature(fun)
来查找与static_argnames
(或反之)对应的任何位置参数。如果同时提供了static_argnums
和static_argnames
,则不会使用inspect.signature
,并且只有static_argnums
或static_argnames
中列出的实际参数才会被视为静态。static_argnames (str | Iterable[str] | None | None) – 可选,一个字符串或字符串集合,用于指定将哪些命名参数视为静态的(编译时常量)。有关详细信息,请参阅关于
static_argnums
的注释。如果未提供但设置了static_argnums
,则默认值基于调用inspect.signature(fun)
来查找相应的命名参数。donate_argnums (int | Sequence[int] | None | None) –
可选,整数集合,用于指定哪些位置参数缓冲区可以被计算覆盖,并在调用方中标记为已删除。如果您在计算开始后不再需要参数缓冲区,则可以安全地捐赠它们。在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如回收您的一个输入缓冲区来存储结果。您不应该重复使用您捐赠给计算的缓冲区;如果您尝试这样做,JAX 将引发错误。默认情况下,不捐赠任何参数缓冲区。
如果未提供
donate_argnums
也未提供donate_argnames
,则不捐赠任何参数。如果未提供donate_argnums
但提供了donate_argnames
,或者反之,JAX 会使用inspect.signature(fun)
来查找与donate_argnames
(或反之)对应的任何位置参数。如果同时提供了donate_argnums
和donate_argnames
,则不会使用inspect.signature
,并且只有donate_argnums
或donate_argnames
中列出的实际参数才会被捐赠。有关缓冲区捐赠的更多详细信息,请参阅 FAQ。
donate_argnames (str | Iterable[str] | None | None) – 可选,一个字符串或字符串集合,用于指定将哪些命名参数捐赠给计算。有关详细信息,请参阅关于
donate_argnums
的注释。如果未提供但设置了donate_argnums
,则默认值基于调用inspect.signature(fun)
来查找相应的命名参数。keep_unused (bool) – 可选的布尔值。如果为 False(默认值),则 JAX 确定 fun 未使用的参数可能会从生成的已编译 XLA 可执行文件中删除。此类参数将不会传输到设备,也不会提供给底层可执行文件。如果为 True,则不会修剪未使用的参数。
device (xc.Device | None | None) – 这是一个实验性功能,API 可能会更改。可选,JIT 编译的函数将在其上运行的设备。(可以通过
jax.devices()
检索可用设备。)默认值继承自 XLA 的 DeviceAssignment 逻辑,通常是使用jax.devices()[0]
。backend ( str | None | None) – 这是一个实验性功能,API 很可能会发生变化。可选参数,表示 XLA 后端的字符串:
'cpu'
、'gpu'
或'tpu'
。inline ( bool) – 可选布尔值。指定此函数是否应内联到封闭的 jaxpr 中。默认为 False。
abstracted_axes (Any | None | None)
- 返回:
一个
fun
的包装版本,设置为即时编译。- 返回类型:
pjit.JitWrapped
示例
在以下示例中,
selu
可以通过 XLA 编译成单个融合内核>>> import jax >>> >>> @jax.jit ... def selu(x, alpha=1.67, lmbda=1.05): ... return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha) >>> >>> key = jax.random.key(0) >>> x = jax.random.normal(key, (10,)) >>> print(selu(x)) [-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748 -0.85743 -0.78232 0.76827 0.59566 ]
要在装饰函数时传递诸如
static_argnames
之类的参数,一种常见的模式是使用functools.partial()
>>> from functools import partial >>> >>> @partial(jax.jit, static_argnames=['n']) ... def g(x, n): ... for i in range(n): ... x = x ** 2 ... return x >>> >>> g(jnp.arange(4), 3) Array([ 0, 1, 256, 6561], dtype=int32)