jax.lax.select_n

内容

jax.lax.select_n#

jax.lax.select_n(which, *cases)[source]#

从多个案例中选择数组值。

概括了 XLA 的 Select 运算符。与 XLA 的版本不同,该运算符是可变的,可以使用整数 pred 从许多案例中进行选择。

参数:
  • which (ArrayLike) – 确定应返回哪个案例。必须是一个包含布尔值或整数值的数组。可以是标量,也可以具有与 cases 匹配的形状。对于每个数组元素, which 的值决定了采用 cases 中的哪一个。 which 必须在范围 [0 .. len(cases)) 内;对于超出该范围的值,行为是实现定义的。

  • *cases (类数组) – 一个非空的数组案例列表。所有案例必须具有相同的 dtype 和相同的形状。

返回:

一个与案例具有相同形状和 dtype 的数组,其值根据 which 选择。

返回类型:

数组