jax.numpy.nonzero#

jax.numpy.nonzero(a, *, size=None, fill_value=None)[源代码]#

返回数组中非零元素的索引。

JAX实现的 numpy.nonzero()

由于 nonzero 的输出大小取决于数据,该函数与 JIT 和其他转换不兼容。JAX 版本添加了可选的 size 参数,必须静态指定该参数,才能在 JAX 的转换中使用 jnp.nonzero

参数:
  • a (ArrayLike) – N 维数组。

  • size (int | None | None) – 可选的静态整数,指定要返回的非零条目的数量。如果非零元素的数量多于指定的 size,则索引将在末尾被截断。如果非零元素的数量少于指定的 size,则索引将用 fill_value 填充,默认为零。

  • fill_value (None | ArrayLike | tuple[ArrayLike, ...] | None) – 当指定 size 时,可选的填充值。默认为 0。

返回:

长度为 a.ndim 的 JAX 数组的元组,包含每个非零值的索引。

返回类型:

tuple[Array, …]

示例

一维数组返回长度为 1 的索引元组

>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jnp.nonzero(x)
(Array([1, 3, 5], dtype=int32),)

二维数组返回长度为 2 的索引元组

>>> x = jnp.array([[0, 5, 0],
...                [6, 0, 7]])
>>> jnp.nonzero(x)
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))

在任何一种情况下,生成的索引元组都可以直接用于提取非零值

>>> indices = jnp.nonzero(x)
>>> x[indices]
Array([5, 6, 7], dtype=int32)

由于返回索引的数量取决于输入数组的内容,因此 nonzero 的输出具有动态形状。 因此,它与 JIT 和其他 JAX 转换不兼容

>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jax.jit(jnp.nonzero)(x)  
Traceback (most recent call last):
  ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.

可以通过传递静态的 size 参数来指定所需的输出形状来解决此问题

>>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size')
>>> nonzero_jit(x, size=3)
(Array([1, 3, 5], dtype=int32),)

如果 size 与真实大小不匹配,则结果将被截断或填充

>>> nonzero_jit(x, size=2)  # size < 3: indices are truncated
(Array([1, 3], dtype=int32),)
>>> nonzero_jit(x, size=5)  # size > 3: indices are padded with zeros.
(Array([1, 3, 5, 0, 0], dtype=int32),)

您可以使用 fill_value 参数为填充指定自定义填充值

>>> nonzero_jit(x, size=5, fill_value=len(x))
(Array([1, 3, 5, 6, 6], dtype=int32),)