jax.numpy.argmax# jax.numpy.argmax(a, axis=None, out=None, keepdims=None)[source]# 沿轴返回最大值的索引。 LAX 后端实现 numpy.argmax(). 原始文档字符串如下。 参数: a (array_like) – 输入数组。 axis (int, 可选) – 默认情况下,索引指向扁平化的数组,否则沿着指定的轴。 keepdims (bool, 可选) – 如果将其设置为 True,则减少的轴将作为大小为一的维度保留在结果中。使用此选项,结果将针对数组正确广播。 out (None | None) 返回值: index_array – 数组中索引的数组。它与 a.shape 形状相同,但去掉了沿着 axis 的维度。如果 keepdims 设置为 True,则 axis 的大小将为 1,结果数组与 a.shape 形状相同。 返回类型: 整数的 ndarray