jax.lax.fori_loop#
- jax.lax.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)[源代码]#
通过归约为
jax.lax.while_loop()
,从lower
循环到upper
。简而言之,类似 Haskell 的类型签名为
fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a
fori_loop
的语义由以下 Python 实现给出:def fori_loop(lower, upper, body_fun, init_val): val = init_val for i in range(lower, upper): val = body_fun(i, val) return val
正如 Python 版本所示,设置
upper <= lower
将不会产生任何迭代。不支持负数或自定义增量。与 Python 版本不同,
fori_loop
的实现依赖于调用jax.lax.while_loop()
或调用jax.lax.scan()
。如果循环次数是静态的(意味着在跟踪时已知,可能是因为lower
和upper
是 Python 整数文字),那么fori_loop
将使用scan()
实现,并支持反向模式自动微分;否则,将使用while_loop
,并且不支持反向模式自动微分。有关详细信息,请参阅这些函数的文档字符串。同样与 Python 版本不同,循环携带值
val
在所有迭代中必须保持固定的形状和数据类型(而不仅仅是在 NumPy 秩/形状广播和数据类型提升规则下保持一致)。换句话说,上面的类型签名中的类型a
表示具有固定形状和数据类型的数组(或具有固定结构和叶子处具有固定形状和数据类型的数组的嵌套元组/列表/字典容器数据结构)。注意
fori_loop()
会编译body_fun
,因此虽然它可以与jit()
结合使用,但通常是不必要的。- 参数:
lower – 表示循环索引下界(包含)的整数
upper – 表示循环索引上界(不包含)的整数
body_fun – 类型为
(int, a) -> a
的函数。init_val – 类型为
a
的初始循环携带值。unroll (int | bool | None) – 一个可选的整数或布尔值,用于确定循环展开的程度。如果提供整数,它将确定在循环的单个滚动迭代中运行多少次展开的循环迭代。如果提供布尔值,它将确定循环是否完全展开(即 unroll=True)或完全不展开(即 unroll=False)。此参数仅在循环边界静态已知时适用。
- 返回值:
来自最后一次迭代的循环值,类型为
a
。