jax.numpy.nanargmin#

jax.numpy.nanargmin(a, axis=None, out=None, keepdims=None)[源代码]#

返回一个数组中最小值(忽略 NaN 值)的索引。

JAX 实现的 numpy.nanargmin()

参数:
  • a (ArrayLike) – 输入数组

  • axis (int | None | None) – 可选整数,指定沿哪个轴查找最小值。 如果未指定 axis,则 a 将被展平。

  • out (None | None) – JAX 未使用

  • keepdims (bool | None | None) – 如果为 True,则返回与 a 具有相同维度的数组。

返回值:

一个包含沿指定轴的最小值索引的数组。

返回类型:

Array

注意

如果一个轴上的值全部为 NaN,则返回的索引将为 -1。 这与 numpy.nanargmin() 的行为不同,后者会引发错误。

另请参阅

示例

>>> x = jnp.array([jnp.nan, 3, 5, 4, 2])
>>> jnp.nanargmin(x)
Array(4, dtype=int32)
>>> x = jnp.array([[1, 3, jnp.nan],
...                [5, 4, jnp.nan]])
>>> jnp.nanargmin(x, axis=1)
Array([0, 1], dtype=int32)
>>> jnp.nanargmin(x, axis=1, keepdims=True)
Array([[0],
       [1]], dtype=int32)