jax.numpy.where#

jax.numpy.where(condition, x=None, y=None, /, *, size=None, fill_value=None)[源代码]#

根据条件从两个数组中选择元素。

numpy.where()的 JAX 实现。

注意

当仅提供 condition 时,jnp.where(condition) 等同于 jnp.nonzero(condition)。对于这种情况,请参阅 jax.numpy.nonzero() 的文档。下面的文档字符串重点介绍指定了 xy 的情况。

jnp.where 的三项版本会降级为 jax.lax.select()

参数:
  • condition – 布尔数组。当指定 xy 时,必须与它们进行广播兼容。

  • x – 类数组。应与 conditiony 进行广播兼容,并与 y 类型转换兼容。

  • y – 类数组。应与 conditionx 进行广播兼容,并与 x 类型转换兼容。

  • size – 整数,仅当 xyNone 时引用。有关详细信息,请参阅 jax.numpy.nonzero()

  • fill_value – 仅当 xyNone 时引用。有关详细信息,请参阅 jax.numpy.nonzero()

返回:

一个 dtype 为 jnp.result_type(x, y) 的数组,其值来自 x,当 condition 为 True 时,值来自 y,当 condition 为 False 时。如果 xyNone,则函数的行为不同;有关返回类型的描述,请参阅 jax.numpy.nonzero()

笔记

jax.numpy.where()xy 输入可能具有 NaN 值时,需要特别注意。具体来说,当使用 jax.grad() (反向模式微分) 获取梯度时,无论 condition 的值如何,xy 中的 NaN 都会传播到梯度中。有关此行为和解决方法,请参阅 JAX 常见问题解答

示例

当不提供 xy 时,where 的行为等同于 jax.numpy.nonzero()

>>> x = jnp.arange(10)
>>> jnp.where(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
>>> jnp.nonzero(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)

当提供 xy 时,where 会根据指定的条件在它们之间进行选择

>>> jnp.where(x > 4, x, 0)
Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)