jax.lax.argmax

内容

jax.lax.argmax#

jax.lax.argmax(operand, axis, index_dtype)[source]#

沿 axis 计算最大元素的索引。

参数:
  • operand (ArrayLike)

  • axis (int)

  • index_dtype (DTypeLike)

返回类型:

数组