jax.numpy.argwhere

内容

jax.numpy.argwhere#

jax.numpy.argwhere(a, *, size=None, fill_value=None)[source]#

查找非零数组元素的索引

JAX 对 numpy.argwhere() 的实现。

jnp.argwhere(x) 本质上等效于 jnp.column_stack(jnp.nonzero(x)),并针对零维(即标量)输入进行特殊处理。

由于 argwhere 输出的大小取决于数据,因此该函数通常与 JIT 不兼容。JAX 版本添加了可选的 size 参数,它指定输出前导维的大小 - 必须静态指定它才能使 jnp.argwhere 编译非静态操作数。有关 size 及其语义的完整讨论,请参阅 jax.numpy.nonzero()

参数:
  • a (ArrayLike) – 用于查找非零元素的数组

  • size (int | None | None) – 指定预期非零元素数量的可选整数。为了在 JAX 变换中使用 argwhere(例如 jax.jit()),必须指定此参数。有关更多信息,请参阅 jax.numpy.nonzero()

  • fill_value (ArrayLike | None | None) – 当指定 size 时,指定填充值的可选数组。有关更多信息,请参阅 jax.numpy.nonzero()

返回值:

形状为 [size, x.ndim] 的二维数组。如果 size 未作为参数指定,则它等于 x 中非零元素的数量。

返回类型:

数组

示例

二维数组

>>> x = jnp.array([[1, 0, 2],
...                [0, 3, 0]])
>>> jnp.argwhere(x)
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)

使用 jax.numpy.column_stack()jax.numpy.nonzero() 的等效计算

>>> jnp.column_stack(jnp.nonzero(x))
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)

零维(即标量)输入的特殊情况

>>> jnp.argwhere(1)
Array([], shape=(1, 0), dtype=int32)
>>> jnp.argwhere(0)
Array([], shape=(0, 0), dtype=int32)