jax.lax.switch

内容

jax.lax.switch#

jax.lax.switch(index, branches, *operands, operand=<object object>)[source]#

根据 index 应用 branches 中的恰好一个。

如果 index 超出范围,则将其钳制在范围内。

具有以下 Python 语义

def switch(index, branches, *operands):
  index = clamp(0, index, len(branches) - 1)
  return branches[index](*operands)

在内部,这包装了 XLA 的 条件 运算符。但是,当使用 vmap() 进行转换以对谓词批次进行操作时,cond 将转换为 select()

参数::
  • index – 整数标量类型,指示要应用哪个分支函数。

  • branches (Sequence[Callable]) – 基于 index 要应用的函数序列 (A -> B)。

  • operands – 应用哪个分支的运算符 (A) 输入。

返回值::

根据 index 选择的分支的 branch(*operands) 的值 (B)。