jax.numpy.flatnonzero#
- jax.numpy.flatnonzero(a, *, size=None, fill_value=None)[source]#
返回扁平化数组中非零元素的索引
JAX 对
numpy.flatnonzero()
的实现。jnp.flatnonzero(x)
等效于nonzero(ravel(a))[0]
。有关此函数参数的完整讨论,请参考jax.numpy.nonzero()
。- 参数:
a (ArrayLike) – N 维数组。
size (int | None | None) – 可选的静态整数,指定要返回的非零条目数。有关此参数的更多讨论,请参见
jax.numpy.nonzero()
。fill_value (None | ArrayLike | tuple[ArrayLike, ...] | None) – 当指定
size
时,可选的填充值。默认为 0。有关此参数的更多讨论,请参见jax.numpy.nonzero()
。
- 返回:
包含扁平化数组中每个非零值的索引的数组。
- 返回类型:
示例
>>> x = jnp.array([[0, 5, 0], ... [6, 0, 8]]) >>> jnp.flatnonzero(x) Array([1, 3, 5], dtype=int32)
这等效于对扁平化数组调用
nonzero()
,并提取结果元组中的第一个条目。>>> jnp.nonzero(x.ravel())[0] Array([1, 3, 5], dtype=int32)
返回的索引可用于从扁平化数组中提取非零条目。
>>> indices = jnp.flatnonzero(x) >>> x.ravel()[indices] Array([5, 6, 8], dtype=int32)