jax.numpy.where

内容

jax.numpy.where#

jax.numpy.where(condition: ArrayLike, x: Literal[None] = None, y: Literal[None] = None, /, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) tuple[Array, ...][source]#
jax.numpy.where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, /, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) Array
jax.numpy.where(condition: ArrayLike, x: ArrayLike | None = None, y: ArrayLike | None = None, /, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) Array | tuple[Array, ...]

根据条件从两个数组中选择元素。

JAX 实现 numpy.where().

注意

当只提供 condition 时,jnp.where(condition) 等效于 jnp.nonzero(condition)。对于这种情况,请参阅 jax.numpy.nonzero() 的文档。以下文档重点介绍了指定了 xy 的情况。

jnp.where 的三元版本降低到 jax.lax.select().

参数:
  • condition – 布尔型数组。当指定了 xy 时,必须与 xy 广播兼容。

  • x – 类数组。应与 conditiony 广播兼容,并与 y 类型转换兼容。

  • y – 类数组。应与 conditionx 广播兼容,并与 x 类型转换兼容。

  • size – 整数,仅在 xyNone 时引用。有关详细信息,请参阅 jax.numpy.nonzero().

  • fill_value – 仅在 xyNone 时引用。有关详细信息,请参阅 jax.numpy.nonzero().

返回值:

一个类型为 jnp.result_type(x, y) 的数组,其值在 condition 为 True 时从 x 中获取,在 condition 为 False 时从 y 中获取。如果 xyNone,则函数的行为不同;有关返回值类型的说明,请参阅 jax.numpy.nonzero().

笔记

jax.numpy.where()xy 输入可能包含 NaN 值时,需要特别注意。具体而言,当使用 jax.grad()(反向模式微分)获取梯度时,xy 中的 NaN 将传播到梯度中,而不管 condition 的值如何。有关此行为和解决方法的更多信息,请访问 JAX 常见问题解答.

示例

当未提供 xy 时,where 的行为等效于 jax.numpy.nonzero()

>>> x = jnp.arange(10)
>>> jnp.where(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
>>> jnp.nonzero(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)

当提供 xy 时,where 会根据指定的条件在它们之间进行选择。

>>> jnp.where(x > 4, x, 0)
Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)