jax.lax.scan

内容

jax.lax.scan#

jax.lax.scan(f, init, xs=None, length=None, reverse=False, unroll=1, _split_transpose=False)[source]#

在保留状态的同时,沿着前导数组轴扫描函数。

简而言之,Haskell 式类型签名

scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

其中,对于任何数组类型说明符t[t]表示具有额外前导轴的类型,如果t是具有数组叶节点的pytree(容器)类型,则[t]表示具有相同pytree结构的类型,并且每个叶节点都具有一个额外的前导轴。

xs(上面表示为a)的类型是数组类型或None,并且ys(上面表示为b)的类型是数组类型时,scan()的语义由以下Python实现大致给出

def scan(f, init, xs, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)

与该Python版本不同,xsys都可以是任意pytree值,因此可以同时扫描多个数组并产生多个输出数组。None实际上是这种情况的一个特例,因为它表示一个空的pytree。

同样与该Python版本不同,scan()是一个JAX基本运算,并被降低为单个WhileOp。这使得它对于减少JIT编译函数的编译时间很有用,因为jit()函数中的原生Python循环结构会被展开,从而导致大型XLA计算。

最后,循环携带值carry必须在所有迭代中保持固定的形状和数据类型(而不仅仅是在NumPy秩/形状广播和数据类型提升规则下保持一致,例如)。换句话说,上面类型签名中的类型c表示具有固定形状和数据类型的数组(或具有固定结构的嵌套元组/列表/字典容器数据结构,以及叶节点处具有固定形状和数据类型的数组)。

注意

scan()会编译f,因此虽然它可以与jit()结合使用,但通常没有必要。

参数:
  • f (Callable[[Carry, X], tuple[Carry, Y]]) – 要扫描的Python函数,类型为c -> a -> (c, b),这意味着f接受两个参数,第一个参数是循环携带的值,第二个参数是xs沿其前导轴的切片,并且f返回一个对,其中第一个元素表示循环携带的新值,第二个元素表示输出的切片。

  • init (Carry) – 类型为c的初始循环携带值,可以是标量、数组或任何pytree(嵌套的Python元组/列表/字典),表示初始循环携带值。此值必须与f返回的对的第一个元素具有相同的结构。

  • xs (X | None) – 类型为[a]的值,沿其前导轴扫描,其中[a]可以是数组或任何pytree(嵌套的Python元组/列表/字典),并且具有一致的前导轴大小。

  • length (int | None) – 可选整数,指定循环迭代次数,必须与xs中数组的前导轴大小一致(但可用于执行不需要任何输入xs的扫描)。

  • reverse (bool) – 可选布尔值,指定是向前运行扫描迭代(默认值)还是向后运行,相当于反转xsys中数组的前导轴。

  • unroll (int | bool) – 可选正整数或布尔值,在扫描基本运算的底层操作中,指定在循环的单个迭代中展开多少个扫描迭代。如果提供整数,则它确定在循环的单个卷积迭代中运行多少个展开的循环迭代。如果提供布尔值,它将确定循环是否完全展开(即unroll=True)或完全不展开(即unroll=False)。

  • _split_transpose (bool) – 实验性可选布尔值,指定是否将转置进一步拆分为扫描(计算激活梯度)和映射(计算对应于数组参数的梯度)。启用此功能可能会增加内存需求,因此这是一个实验性功能,可能会发展甚至回滚。

返回值:

类型为(c, [b])的对,其中第一个元素表示最终的循环携带值,第二个元素表示当在输入的前导轴上扫描时,f的第二个输出的堆叠输出。

返回类型:

tuple[Carry, Y]