jax.numpy.argwhere#

jax.numpy.argwhere(a, *, size=None, fill_value=None)[源代码]#

查找数组中非零元素的索引

numpy.argwhere() 的 JAX 实现。

jnp.argwhere(x) 本质上等同于 jnp.column_stack(jnp.nonzero(x)),但对零维(即标量)输入进行了特殊处理。

由于 argwhere 的输出大小依赖于数据,因此该函数通常与 JIT 不兼容。JAX 版本添加了可选的 size 参数,用于指定输出的前导维度的大小 - 必须静态指定,才能使 jnp.argwhere 与非静态操作数一起编译。 有关 size 及其语义的完整讨论,请参阅 jax.numpy.nonzero()

参数:
  • a (ArrayLike) – 要查找非零元素的数组

  • size (int | None | None) – 可选整数,静态指定预期的非零元素数量。 必须指定此项,才能在 JAX 变换(如 jax.jit())中使用 argwhere。 有关更多信息,请参阅 jax.numpy.nonzero()

  • fill_value (ArrayLike | None | None) – 可选数组,指定指定 size 时的填充值。 有关更多信息,请参阅 jax.numpy.nonzero()

返回:

形状为 [size, x.ndim] 的二维数组。 如果未将 size 指定为参数,则它等于 x 中的非零元素数。

返回类型:

数组

示例

二维数组

>>> x = jnp.array([[1, 0, 2],
...                [0, 3, 0]])
>>> jnp.argwhere(x)
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)

使用 jax.numpy.column_stack()jax.numpy.nonzero() 的等效计算

>>> jnp.column_stack(jnp.nonzero(x))
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)

零维(即标量)输入的特殊情况

>>> jnp.argwhere(1)
Array([], shape=(1, 0), dtype=int32)
>>> jnp.argwhere(0)
Array([], shape=(0, 0), dtype=int32)