jax.numpy.argmin#
- jax.numpy.argmin(a, axis=None, out=None, keepdims=None)[源代码]#
返回数组中最小值对应的索引。
numpy.argmin()
的 JAX 实现。- 参数:
a (类数组) – 输入数组
axis (int | None | None) – 可选整数,指定沿哪个轴查找最小值。如果未指定
axis
,则a
将被展平。out (None | None) – JAX 未使用。
keepdims (bool | None | None) – 如果为 True,则返回与
a
具有相同维度数的数组。
- 返回:
一个数组,包含沿指定轴的最小值的索引。
- 返回类型:
另请参阅
jax.numpy.argmax()
:返回最大值的索引。jax.numpy.nanargmin()
:计算argmin
,同时忽略 NaN 值。
示例
>>> x = jnp.array([1, 3, 5, 4, 2]) >>> jnp.argmin(x) Array(0, dtype=int32)
>>> x = jnp.array([[1, 3, 2], ... [5, 4, 1]]) >>> jnp.argmin(x, axis=1) Array([0, 2], dtype=int32)
>>> jnp.argmin(x, axis=1, keepdims=True) Array([[0], [2]], dtype=int32)