jax.numpy.nanargmax

内容

jax.numpy.nanargmax#

jax.numpy.nanargmax(a, axis=None, out=None, keepdims=None)[source]#

返回指定轴上的最大值的索引,忽略

LAX 后端实现 numpy.nanargmax().

警告: jax.numpy.argmax 对全 NaN 片返回 -1,不会引发错误。

原始文档字符串如下。

NaNs. 对于全 NaN 片,会引发 ValueError. 警告: 如果片仅包含 NaN 和 -Infs,则结果不可信。

参数:
  • a (array_like) – 输入数据。

  • axis (int, optional) – 操作的轴。默认情况下使用扁平化输入。

  • keepdims (bool, optional) – 如果设置为 True,则缩减的轴在结果中保留为大小为一的维度。使用此选项,结果将与数组正确广播。

  • out (None | None)

返回值:

index_array – 索引数组或单个索引值。

返回值类型:

ndarray