jax.nn.softmax#

jax.nn.softmax(x, axis=-1, where=None, initial=_UNSPECIFIED)[源代码]#

Softmax 函数。

计算将元素重新缩放到 \([0, 1]\) 范围的函数,使得沿 axis 的元素总和为 \(1\)

\[\mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
参数:
  • x (ArrayLike) – 输入数组

  • axis (int | tuple[int, ...] | None) – 应计算 softmax 的轴或轴。 在这些维度上求和的 softmax 输出应总和为 \(1\)。 可以是整数或整数元组。

  • where (ArrayLike | None | None) – 要包含在 softmax 中的元素。

  • 初始值 (未指定)

返回:

一个数组。

返回类型:

数组

注意

如果任何输入值是 +inf,结果将全部为 NaN:这反映了在浮点数学的上下文中 inf / inf 没有明确定义的的事实。

另请参阅

log_softmax()