jax.numpy.select#
- jax.numpy.select(condlist, choicelist, default=0)[源代码]#
根据一系列条件选择值。
JAX 实现了
numpy.select()
,它基于jax.lax.select_n()
实现。- 参数:
condlist (Sequence[ArrayLike]) – 类数组条件的序列。所有条目必须相互广播兼容。
choicelist (Sequence[ArrayLike]) – 要选择的类数组值的序列。必须与
condlist
的长度相同,并且所有条目必须与condlist
的条目广播兼容。default (ArrayLike) – 当所有条件均为 False 时返回的值(默认值:0)。
- 返回值:
从
choicelist
中选择的值的数组,对应于每个位置在condlist
中第一个为True
的条目。- 返回类型:
另请参阅
jax.numpy.where()
: 基于单个条件在两个值之间进行选择。jax.lax.select_n()
: 基于索引在 N 个值之间进行选择。
示例
>>> condlist = [ ... jnp.array([False, True, False, False]), ... jnp.array([True, False, False, False]), ... jnp.array([False, True, True, False]), ... ] >>> choicelist = [ ... jnp.array([1, 2, 3, 4]), ... jnp.array([10, 20, 30, 40]), ... jnp.array([100, 200, 300, 400]), ... ] >>> jnp.select(condlist, choicelist, default=0) Array([ 10, 2, 300, 0], dtype=int32)
这在逻辑上等同于以下嵌套的
where
语句>>> default = 0 >>> jnp.where(condlist[0], ... choicelist[0], ... jnp.where(condlist[1], ... choicelist[1], ... jnp.where(condlist[2], ... choicelist[2], ... default))) Array([ 10, 2, 300, 0], dtype=int32)
但是,为了效率,它基于
jax.lax.select_n()
实现。