jax.lax.cond

内容

jax.lax.cond#

jax.lax.cond(pred, true_fun, false_fun, *operands, operand=<object object>)[source]#

有条件地应用 true_funfalse_fun

封装 XLA 的 Conditional 运算符。

在提供参数类型正确的情况下,cond() 与以下 Python 实现具有相同的语义,其中 pred 必须是标量类型

def cond(pred, true_fun, false_fun, *operands):
  if pred:
    return true_fun(*operands)
  else:
    return false_fun(*operands)

jax.lax.select() 相比,使用 cond 表示只执行两个分支中的一个(直到编译器重写和优化)。但是,当使用 vmap() 转换为对一批谓词进行操作时,cond 会转换为 select()

参数:
  • pred – 布尔标量类型,指示要应用哪个分支函数。

  • true_fun (Callable) – 函数 (A -> B),如果 pred 为 True,则应用此函数。

  • false_fun (Callable) – 函数 (A -> B),如果 pred 为 False,则应用此函数。

  • operands – 操作数 (A) 输入到任一分支,具体取决于 pred。类型可以是标量、数组或任何 pytree(嵌套的 Python 元组/列表/字典)。

返回值:

根据 pred 的值,返回 true_fun(*operands)false_fun(*operands) 的值 (B)。类型可以是标量、数组或任何 pytree(嵌套的 Python 元组/列表/字典)。