jax.lax.cond#
- jax.lax.cond(pred, true_fun, false_fun, *operands, operand=<object object>)[source]#
有条件地应用
true_fun
或false_fun
。封装 XLA 的 Conditional 运算符。
在提供参数类型正确的情况下,
cond()
与以下 Python 实现具有相同的语义,其中pred
必须是标量类型def cond(pred, true_fun, false_fun, *operands): if pred: return true_fun(*operands) else: return false_fun(*operands)
与
jax.lax.select()
相比,使用cond
表示只执行两个分支中的一个(直到编译器重写和优化)。但是,当使用vmap()
转换为对一批谓词进行操作时,cond
会转换为select()
。- 参数:
pred – 布尔标量类型,指示要应用哪个分支函数。
true_fun (Callable) – 函数 (A -> B),如果
pred
为 True,则应用此函数。false_fun (Callable) – 函数 (A -> B),如果
pred
为 False,则应用此函数。operands – 操作数 (A) 输入到任一分支,具体取决于
pred
。类型可以是标量、数组或任何 pytree(嵌套的 Python 元组/列表/字典)。
- 返回值:
根据
pred
的值,返回true_fun(*operands)
或false_fun(*operands)
的值 (B)。类型可以是标量、数组或任何 pytree(嵌套的 Python 元组/列表/字典)。