jax.numpy.nonzero#
- jax.numpy.nonzero(a, *, size=None, fill_value=None)[source]#
返回数组中非零元素的索引。
JAX 实现
numpy.nonzero()
.由于
nonzero
输出的大小依赖于数据,因此该函数与 JIT 和其他转换不兼容。JAX 版本添加了可选的size
参数,该参数必须静态指定才能在 JAX 的转换中使用jnp.nonzero
。- 参数:
- 返回:
长度为
a.ndim
的 JAX 数组元组,包含每个非零值的索引。- 返回类型:
示例
一维数组返回长度为 1 的索引元组
>>> x = jnp.array([0, 5, 0, 6, 0, 7]) >>> jnp.nonzero(x) (Array([1, 3, 5], dtype=int32),)
二维数组返回长度为 2 的索引元组
>>> x = jnp.array([[0, 5, 0], ... [6, 0, 7]]) >>> jnp.nonzero(x) (Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
在这两种情况下,生成的索引元组都可以直接用于提取非零值
>>> indices = jnp.nonzero(x) >>> x[indices] Array([5, 6, 7], dtype=int32)
nonzero
的输出具有动态形状,因为返回的索引数量取决于输入数组的内容。因此,它与 JIT 和其他 JAX 变换不兼容。>>> x = jnp.array([0, 5, 0, 6, 0, 7]) >>> jax.jit(jnp.nonzero)(x) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]. The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
可以通过传递静态
size
参数来指定所需的输出形状来解决此问题。>>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size') >>> nonzero_jit(x, size=3) (Array([1, 3, 5], dtype=int32),)
如果
size
与真实大小不匹配,则结果将被截断或填充。>>> nonzero_jit(x, size=2) # size < 3: indices are truncated (Array([1, 3], dtype=int32),) >>> nonzero_jit(x, size=5) # size > 3: indices are padded with zeros. (Array([1, 3, 5, 0, 0], dtype=int32),)
您可以使用
fill_value
参数指定填充的自定义填充值。>>> nonzero_jit(x, size=5, fill_value=len(x)) (Array([1, 3, 5, 6, 6], dtype=int32),)