jax.lax.while_loop

内容

jax.lax.while_loop#

jax.lax.while_loop(cond_fun, body_fun, init_val)[source]#

cond_fun 为真时,重复调用 body_fun 以循环形式。

简要的 Haskell 式类型签名

while_loop :: (a -> Bool) -> (a -> a) -> a -> a

while_loop 的语义由以下 Python 实现给出

def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val

与该 Python 版本不同,while_loop 是 JAX 原语,并被降低为单个 WhileOp。这使得它对于减少 jit 编译函数的编译时间很有用,因为 @jit 函数中的原生 Python 循环结构会被展开,从而导致大型 XLA 计算。

与 Python 等价物不同,循环携带的值 val 必须在所有迭代中保持固定的形状和数据类型(而不仅仅是根据 NumPy 秩/形状广播和数据类型提升规则一致)。换句话说,上面类型签名中的类型 a 代表一个具有固定形状和数据类型的数组(或一个嵌套的元组/列表/字典容器数据结构,该结构具有固定的结构,并且在叶子处具有固定形状和数据类型的数组)。

与使用 Python 原生循环构造的另一个区别是 while_loop 不可反向模式微分,因为 XLA 计算需要对内存需求进行静态边界。

注意

while_loop() 会编译 cond_funbody_fun,因此尽管它可以与 jit() 结合使用,但通常没有必要。

参数:
  • cond_fun (Callable[[T], BooleanNumeric]) – 类型为 a -> Bool 的函数。

  • body_fun (Callable[[T], T]) – 类型为 a -> a 的函数。

  • init_val (T) – 类型为 a 的值,该类型可以是标量、数组或任何其 pytree(嵌套的 Python 元组/列表/字典),代表初始循环携带值。

返回值:

body_fun 最后一次迭代的输出,类型为 a

返回类型:

T