jax.numpy.where#
- jax.numpy.where(condition, x=None, y=None, /, *, size=None, fill_value=None)[源代码]#
根据条件从两个数组中选择元素。
numpy.where()
的 JAX 实现。注意
当仅提供
condition
时,jnp.where(condition)
等价于jnp.nonzero(condition)
。对于这种情况,请参阅jax.numpy.nonzero()
的文档。下面的文档字符串侧重于指定x
和y
的情况。三项版本的
jnp.where
会降级为jax.lax.select()
。- 参数:
condition – 布尔数组。当指定
x
和y
时,必须与它们兼容广播。x – 类数组。应与
condition
和y
兼容广播,并且与y
类型转换兼容。y – 类数组。应与
condition
和x
兼容广播,并且与x
类型转换兼容。size – 整数,仅当
x
和y
为None
时引用。详情请参阅jax.numpy.nonzero()
。fill_value – 仅当
x
和y
为None
时引用。详情请参阅jax.numpy.nonzero()
。
- 返回:
一个 dtype 为
jnp.result_type(x, y)
的数组,其值从x
中提取(当condition
为 True 时),以及从y
中提取(当 condition 为False
时)。如果x
和y
为None
,则此函数的行为有所不同;请参阅jax.numpy.nonzero()
查看返回类型的描述。
笔记
当
jax.numpy.where()
的x
或y
输入可能包含 NaN 值时,需要特别注意。具体来说,当使用jax.grad()
(反向模式微分)进行梯度计算时,无论condition
的值如何,x
或y
中的 NaN 都将传播到梯度中。有关此行为和解决方法,请参阅 JAX FAQ。示例
当未提供
x
和y
时,where
的行为等同于jax.numpy.nonzero()
>>> x = jnp.arange(10) >>> jnp.where(x > 4) (Array([5, 6, 7, 8, 9], dtype=int32),) >>> jnp.nonzero(x > 4) (Array([5, 6, 7, 8, 9], dtype=int32),)
当提供了
x
和y
时,where
会根据指定的条件在它们之间进行选择>>> jnp.where(x > 4, x, 0) Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)