jax.numpy.argmin#

jax.numpy.argmin(a, axis=None, out=None, keepdims=None)[源代码]#

返回数组中最小值对应的索引。

numpy.argmin() 的 JAX 实现。

参数
  • a (类数组) – 输入数组

  • axis (int | None | None) – 可选整数,指定沿哪个轴查找最小值。如果未指定 axis,则 a 将被展平。

  • out (None | None) – JAX 未使用。

  • keepdims (bool | None | None) – 如果为 True,则返回与 a 具有相同维度数的数组。

返回:

一个数组,包含沿指定轴的最小值的索引。

返回类型:

数组

另请参阅

示例

>>> x = jnp.array([1, 3, 5, 4, 2])
>>> jnp.argmin(x)
Array(0, dtype=int32)
>>> x = jnp.array([[1, 3, 2],
...                [5, 4, 1]])
>>> jnp.argmin(x, axis=1)
Array([0, 2], dtype=int32)
>>> jnp.argmin(x, axis=1, keepdims=True)
Array([[0],
       [2]], dtype=int32)