jax.numpy.flatnonzero

目录

jax.numpy.flatnonzero#

jax.numpy.flatnonzero(a, *, size=None, fill_value=None)[source]#

返回扁平化数组中非零元素的索引

JAX 对 numpy.flatnonzero() 的实现。

jnp.flatnonzero(x) 等效于 nonzero(ravel(a))[0]。有关此函数参数的完整讨论,请参考 jax.numpy.nonzero()

参数:
  • a (ArrayLike) – N 维数组。

  • size (int | None | None) – 可选的静态整数,指定要返回的非零条目数。有关此参数的更多讨论,请参见 jax.numpy.nonzero()

  • fill_value (None | ArrayLike | tuple[ArrayLike, ...] | None) – 当指定 size 时,可选的填充值。默认为 0。有关此参数的更多讨论,请参见 jax.numpy.nonzero()

返回:

包含扁平化数组中每个非零值的索引的数组。

返回类型:

数组

示例

>>> x = jnp.array([[0, 5, 0],
...                [6, 0, 8]])
>>> jnp.flatnonzero(x)
Array([1, 3, 5], dtype=int32)

这等效于对扁平化数组调用 nonzero(),并提取结果元组中的第一个条目。

>>> jnp.nonzero(x.ravel())[0]
Array([1, 3, 5], dtype=int32)

返回的索引可用于从扁平化数组中提取非零条目。

>>> indices = jnp.flatnonzero(x)
>>> x.ravel()[indices]
Array([5, 6, 8], dtype=int32)