jax.numpy.where#
- jax.numpy.where(condition: ArrayLike, x: Literal[None] = None, y: Literal[None] = None, /, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) tuple[Array, ...] [source]#
- jax.numpy.where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, /, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) Array
- jax.numpy.where(condition: ArrayLike, x: ArrayLike | None = None, y: ArrayLike | None = None, /, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) Array | tuple[Array, ...]
根据条件从两个数组中选择元素。
JAX 实现
numpy.where()
.注意
当只提供
condition
时,jnp.where(condition)
等效于jnp.nonzero(condition)
。对于这种情况,请参阅jax.numpy.nonzero()
的文档。以下文档重点介绍了指定了x
和y
的情况。jnp.where
的三元版本降低到jax.lax.select()
.- 参数:
condition – 布尔型数组。当指定了
x
和y
时,必须与x
和y
广播兼容。x – 类数组。应与
condition
和y
广播兼容,并与y
类型转换兼容。y – 类数组。应与
condition
和x
广播兼容,并与x
类型转换兼容。size – 整数,仅在
x
和y
为None
时引用。有关详细信息,请参阅jax.numpy.nonzero()
.fill_value – 仅在
x
和y
为None
时引用。有关详细信息,请参阅jax.numpy.nonzero()
.
- 返回值:
一个类型为
jnp.result_type(x, y)
的数组,其值在condition
为 True 时从x
中获取,在 condition 为False
时从y
中获取。如果x
和y
为None
,则函数的行为不同;有关返回值类型的说明,请参阅jax.numpy.nonzero()
.
笔记
当
jax.numpy.where()
的x
或y
输入可能包含 NaN 值时,需要特别注意。具体而言,当使用jax.grad()
(反向模式微分)获取梯度时,x
或y
中的 NaN 将传播到梯度中,而不管condition
的值如何。有关此行为和解决方法的更多信息,请访问 JAX 常见问题解答.示例
当未提供
x
和y
时,where
的行为等效于jax.numpy.nonzero()
>>> x = jnp.arange(10) >>> jnp.where(x > 4) (Array([5, 6, 7, 8, 9], dtype=int32),) >>> jnp.nonzero(x > 4) (Array([5, 6, 7, 8, 9], dtype=int32),)
当提供
x
和y
时,where
会根据指定的条件在它们之间进行选择。>>> jnp.where(x > 4, x, 0) Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)