jax.lax.scan#

jax.lax.scan(f, init, xs=None, length=None, reverse=False, unroll=1, _split_transpose=False)[源代码]#

在领先的数组轴上扫描函数,同时携带状态。

简而言之,类似 Haskell 的类型签名

scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

其中对于任何数组类型说明符 t[t] 表示具有附加前导轴的类型,如果 t 是具有数组叶的 pytree(容器)类型,则 [t] 表示具有相同 pytree 结构和对应叶子的类型,每个叶子都具有附加的前导轴。

xs 的类型(上面用 a 表示)是数组类型或 None,并且 ys 的类型(上面用 b 表示)是数组类型时,scan() 的语义大致由以下 Python 实现给出:

def scan(f, init, xs, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)

与该 Python 版本不同,xsys 都可以是任意的 pytree 值,因此可以同时扫描多个数组并生成多个输出数组。None 实际上是这种情况的一个特例,因为它表示一个空的 pytree。

同样与该 Python 版本不同,scan() 是一个 JAX 原语,会被降级为单个 WhileOp。这使其对于减少 JIT 编译函数的编译时间很有用,因为 jit() 函数中的原生 Python 循环结构会被展开,从而导致大型的 XLA 计算。

最后,循环携带值 carry 必须在所有迭代中保持固定的形状和数据类型(而不仅仅是在 NumPy 秩/形状广播和数据类型提升规则下保持一致)。换句话说,上面类型签名中的类型 c 表示具有固定形状和数据类型的数组(或具有固定结构和叶子节点上具有固定形状和数据类型的数组的嵌套元组/列表/字典容器数据结构)。

注意

scan() 会编译 f,因此虽然它可以与 jit() 结合使用,但通常是不必要的。

参数:
  • f (Callable[[Carry, X], tuple[Carry, Y]]) – 一个要扫描的 Python 函数,类型为 c -> a -> (c, b),意味着 f 接受两个参数,其中第一个是循环携带的值,第二个是 xs 沿其主轴的切片,并且 f 返回一个对,其中第一个元素表示循环携带的新值,第二个元素表示输出的切片。

  • init (Carry) – 一个类型为 c 的初始循环携带值,它可以是标量、数组或任何 pytree(嵌套的 Python 元组/列表/字典),表示初始循环携带值。该值必须与 f 返回的对的第一个元素具有相同的结构。

  • xs (X | None) – 要沿主轴扫描的值,类型为 [a],其中 [a] 可以是一个数组或任何 pytree(嵌套的 Python 元组/列表/字典),并且具有一致的主轴大小。

  • length (int | None) – 可选的整数,指定循环迭代的次数,该次数必须与 xs 中数组的主轴大小一致(但可用于执行不需要输入 xs 的扫描)。

  • reverse (bool) – 可选的布尔值,指定是正向运行扫描迭代(默认)还是反向运行扫描迭代,相当于反转 xsys 中数组的主轴。

  • unroll (int | bool) – 可选的正整数或布尔值,指定在扫描原语的底层操作中,在一个循环迭代中展开多少扫描迭代。如果提供一个整数,它将决定在一个循环的单个滚动迭代中运行多少展开的循环迭代。如果提供一个布尔值,它将确定循环是完全展开(即 unroll=True)还是完全不展开(即 unroll=False)。

  • _split_transpose (bool) – 可选的实验性布尔值,指定是否将转置进一步拆分为扫描(计算激活梯度)和映射(计算对应于数组参数的梯度)。启用此功能可能会增加内存需求,因此这是一个实验性功能,可能会发展甚至被撤回。

返回:

类型为 (c, [b]) 的对,其中第一个元素表示最终的循环携带值,第二个元素表示在输入的主轴上扫描时 f 的第二个输出的堆叠输出。

返回类型:

tuple[Carry, Y]