jax.lax.argmax# jax.lax.argmax(operand, axis, index_dtype)[source]# 沿 axis 计算最大元素的索引。 参数: operand (ArrayLike) axis (int) index_dtype (DTypeLike) 返回类型: 数组