jax.lax.switch#
- jax.lax.switch(index, branches, *operands, operand=<object object>)[source]#
根据
index
应用branches
中的恰好一个。如果
index
超出范围,则将其钳制在范围内。具有以下 Python 语义
def switch(index, branches, *operands): index = clamp(0, index, len(branches) - 1) return branches[index](*operands)
在内部,这包装了 XLA 的 条件 运算符。但是,当使用
vmap()
进行转换以对谓词批次进行操作时,cond
将转换为select()
。- 参数::
index – 整数标量类型,指示要应用哪个分支函数。
branches (Sequence[Callable]) – 基于
index
要应用的函数序列 (A -> B)。operands – 应用哪个分支的运算符 (A) 输入。
- 返回值::
根据
index
选择的分支的branch(*operands)
的值 (B)。