jax.lax.fori_loop#
- jax.lax.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)[source]#
从
lower
循环到upper
,通过降维到jax.lax.while_loop()
。简而言之,Haskell 式类型签名 为
fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a
The semantics of
fori_loop
are given by this Python implementationdef fori_loop(lower, upper, body_fun, init_val): val = init_val for i in range(lower, upper): val = body_fun(i, val) return val
如 Python 版本所示,设置
upper <= lower
将不会产生任何迭代。不支持负数或自定义增量。与 Python 版本不同,
fori_loop
是通过调用jax.lax.while_loop()
或jax.lax.scan()
来实现的。如果循环次数是静态的(意味着在追踪时已知,可能是因为lower
和upper
是 Python 整数字面量),则fori_loop
是基于scan()
实现的,并且支持反向模式自动微分;否则,将使用while_loop
,并且不支持反向模式自动微分。有关更多信息,请参阅这些函数的文档字符串。同样与 Python 等价物不同,循环携带值
val
必须在所有迭代中保持固定的形状和数据类型(而不仅仅是根据 NumPy 秩/形状广播和数据类型提升规则保持一致)。换句话说,上面类型签名中的类型a
代表一个具有固定形状和数据类型的数组(或一个嵌套的元组/列表/字典容器数据结构,该结构具有固定的结构,并且叶子节点处的数组具有固定形状和数据类型)。注意
fori_loop()
会编译body_fun
,因此虽然它可以与jit()
组合使用,但通常没有必要。- 参数:
- 返回值:
来自最终迭代的循环值,类型为
a
。