jax.numpy.nonzero

内容

jax.numpy.nonzero#

jax.numpy.nonzero(a, *, size=None, fill_value=None)[source]#

返回数组中非零元素的索引。

JAX 实现 numpy.nonzero().

由于 nonzero 输出的大小依赖于数据,因此该函数与 JIT 和其他转换不兼容。JAX 版本添加了可选的 size 参数,该参数必须静态指定才能在 JAX 的转换中使用 jnp.nonzero

参数:
  • a (ArrayLike) – N 维数组。

  • size (int | None | None) – 可选的静态整数,指定要返回的非零条目数量。如果非零元素的数量超过指定的 size,则索引将在末尾截断。如果非零元素的数量少于指定的 size,则索引将用 fill_value 填充,默认为零。

  • fill_value (None | ArrayLike | 元组[ArrayLike, ...] | None) – 当指定 size 时可选的填充值。默认为 0。

返回:

长度为 a.ndim 的 JAX 数组元组,包含每个非零值的索引。

返回类型:

元组[Array, …]

示例

一维数组返回长度为 1 的索引元组

>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jnp.nonzero(x)
(Array([1, 3, 5], dtype=int32),)

二维数组返回长度为 2 的索引元组

>>> x = jnp.array([[0, 5, 0],
...                [6, 0, 7]])
>>> jnp.nonzero(x)
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))

在这两种情况下,生成的索引元组都可以直接用于提取非零值

>>> indices = jnp.nonzero(x)
>>> x[indices]
Array([5, 6, 7], dtype=int32)

nonzero 的输出具有动态形状,因为返回的索引数量取决于输入数组的内容。因此,它与 JIT 和其他 JAX 变换不兼容。

>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jax.jit(jnp.nonzero)(x)  
Traceback (most recent call last):
  ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.

可以通过传递静态 size 参数来指定所需的输出形状来解决此问题。

>>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size')
>>> nonzero_jit(x, size=3)
(Array([1, 3, 5], dtype=int32),)

如果 size 与真实大小不匹配,则结果将被截断或填充。

>>> nonzero_jit(x, size=2)  # size < 3: indices are truncated
(Array([1, 3], dtype=int32),)
>>> nonzero_jit(x, size=5)  # size > 3: indices are padded with zeros.
(Array([1, 3, 5, 0, 0], dtype=int32),)

您可以使用 fill_value 参数指定填充的自定义填充值。

>>> nonzero_jit(x, size=5, fill_value=len(x))
(Array([1, 3, 5, 6, 6], dtype=int32),)