jax.lax.select

内容

jax.lax.select#

jax.lax.select(pred, on_true, on_false)[source]#

根据布尔谓词在两个分支之间进行选择。

包装 XLA 的 Select 操作符。

通常,select() 会导致两个分支都被评估,尽管编译器可能会在可能的情况下省略计算。对于通常只评估一个分支的类似函数,请参阅 cond()

参数:
  • pred (ArrayLike) – 布尔数组

  • on_true (ArrayLike) – 数组,包含在 pred 为 True 时要返回的条目。必须与 pred 具有相同的形状,并且与 on_false 具有相同的形状和数据类型。

  • on_false (ArrayLike) – 数组,包含在 pred 为 False 时要返回的条目。必须与 pred 具有相同的形状,并且与 on_true 具有相同的形状和数据类型。

返回值:

on_trueon_false 具有相同形状和数据类型的数组。

返回类型:

result