jax.lax.select_n#
- jax.lax.select_n(which, *cases)[源代码]#
从多个情况中选择数组值。
推广了 XLA 的 Select 操作符。与 XLA 的版本不同,该操作符是可变的,并且可以使用整数 pred 从多种情况中选择。
- 参数:
which (ArrayLike) – 确定应返回哪种情况。必须是包含布尔值或整数值的数组。可以是标量,也可以具有与
cases
匹配的形状。对于每个数组元素,which
的值决定取cases
中的哪一个。which
必须在范围[0 .. len(cases))
内;对于该范围之外的值,其行为是实现定义的。*cases (ArrayLike) – 非空的数组情况列表。所有情况必须具有相同的数据类型和相同的形状。
- 返回:
一个数组,其形状和数据类型与情况相同,其值根据
which
选择。- 返回类型: