jax.numpy.select

内容

jax.numpy.select#

jax.numpy.select(condlist, choicelist, default=0)[source]#

根据一系列条件选择值。

JAX 实现 numpy.select(),使用 jax.lax.select_n() 实现。

参数:
  • condlist (Sequence[ArrayLike]) – 数组类条件的序列。所有条目都必须是相互广播兼容的。

  • choicelist (Sequence[ArrayLike]) – 要选择的数组类值的序列。必须与 condlist 长度相同,并且所有条目都必须与 condlist 的条目广播兼容。

  • default (ArrayLike) – 当每个条件都为 False 时返回的值(默认值:0)。

返回值:

来自 choicelist 的选定值的数组,对应于每个位置 condlist 中的第一个 True 条目。

返回类型:

数组

另请参阅

示例

>>> condlist = [
...    jnp.array([False, True, False, False]),
...    jnp.array([True, False, False, False]),
...    jnp.array([False, True, True, False]),
... ]
>>> choicelist = [
...    jnp.array([1, 2, 3, 4]),
...    jnp.array([10, 20, 30, 40]),
...    jnp.array([100, 200, 300, 400]),
... ]
>>> jnp.select(condlist, choicelist, default=0)
Array([ 10,   2, 300,   0], dtype=int32)

这在逻辑上等效于以下嵌套的 where 语句

>>> default = 0
>>> jnp.where(condlist[0],
...   choicelist[0],
...   jnp.where(condlist[1],
...     choicelist[1],
...     jnp.where(condlist[2],
...       choicelist[2],
...       default)))
Array([ 10,   2, 300,   0], dtype=int32)

但是,为了提高效率,它是使用jax.lax.select_n()实现的。