jax.lax.scan#
- jax.lax.scan(f, init, xs=None, length=None, reverse=False, unroll=1, _split_transpose=False)[源代码]#
在沿前导数组轴扫描函数的同时传递状态。
简而言之,类 Haskell 类型签名为
scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
其中,对于任何数组类型说明符
t
,[t]
表示具有额外前导轴的类型,如果t
是具有数组叶子的 pytree(容器)类型,则[t]
表示具有相同 pytree 结构和相应叶子(每个叶子都具有额外前导轴)的类型。当
xs
的类型(在上面表示为 a)是数组类型或 None,并且ys
的类型(在上面表示为 b)是数组类型时,scan()
的语义大致由以下 Python 实现给出:def scan(f, init, xs, length=None): if xs is None: xs = [None] * length carry = init ys = [] for x in xs: carry, y = f(carry, x) ys.append(y) return carry, np.stack(ys)
与该 Python 版本不同,
xs
和ys
都可以是任意的 pytree 值,因此可以同时扫描多个数组并生成多个输出数组。None
实际上是这种情况的一个特例,因为它表示一个空的 pytree。此外,与该 Python 版本不同,
scan()
是 JAX 的一个原语,它会被降级为一个 WhileOp。这使得它对于减少 JIT 编译函数的编译时间很有用,因为jit()
函数中的原生 Python 循环结构会被展开,从而导致大型的 XLA 计算。最后,循环携带的值
carry
在所有迭代中必须保持固定的形状和 dtype(而不仅仅是在 NumPy 的 rank/shape 广播和 dtype 提升规则中保持一致)。换句话说,上面类型签名中的类型c
表示具有固定形状和 dtype 的数组(或具有固定结构并在叶子处具有固定形状和 dtype 的数组的嵌套元组/列表/字典容器数据结构)。注意
scan()
会编译f
,因此,虽然它可以与jit()
结合使用,但通常是不必要的。- 参数:
f (Callable[[Carry, X], tuple[Carry, Y]]) – 要扫描的 Python 函数,类型为
c -> a -> (c, b)
,这意味着f
接受两个参数,其中第一个是循环携带的值,第二个是xs
沿其前导轴的切片,并且f
返回一个对,其中第一个元素表示循环携带的新值,第二个元素表示输出的切片。init (Carry) – 类型为
c
的初始循环携带值,可以是标量、数组或任何 pytree(嵌套的 Python 元组/列表/字典),表示初始循环携带值。此值必须与f
返回的对的第一个元素具有相同的结构。xs (X | None) – 类型为
[a]
的值,沿前导轴扫描,其中[a]
可以是数组或任何 pytree(嵌套的 Python 元组/列表/字典),并且具有一致的前导轴大小。length (int | None) – 可选整数,指定循环迭代的次数,该次数必须与
xs
中数组的前导轴大小一致(但可用于执行不需要任何输入xs
的扫描)。reverse (bool) – 可选布尔值,指定是向前运行扫描迭代(默认)还是反向运行,等效于反转
xs
和ys
中数组的前导轴。unroll (int | bool) – 可选正整数或布尔值,指定在扫描原语的底层操作中,在一个循环迭代中展开多少次扫描迭代。如果提供整数,则它确定在循环的单个滚动迭代中运行多少次展开的循环迭代。如果提供布尔值,它将确定循环是否完全展开(即 unroll=True)或完全不展开(即 unroll=False)。
_split_transpose (bool) – 可选的实验性布尔值,指定是否进一步将转置拆分为扫描(计算激活梯度)和映射(计算与数组参数对应的梯度)。启用此功能可能会增加内存需求,因此这是一项实验性功能,可能会发展甚至回滚。
- 返回:
类型为
(c, [b])
的对,其中第一个元素表示最终的循环携带值,第二个元素表示在扫描输入的前导轴时f
的第二个输出的堆叠输出。- 返回类型:
tuple[Carry, Y]