jax.scipy.special.softmax

内容

jax.scipy.special.softmax#

jax.scipy.special.softmax(x, /, *, axis=None)[source]#

Softmax 函数。

JAX 对 scipy.special.softmax() 的实现。

计算将元素重新缩放到范围 \([0, 1]\) 的函数,使得沿 axis 的元素之和为 \(1\)

\[\mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
参数::
  • x (ArrayLike) – 输入数组

  • axis (int | tuple[int, ...] | None | None) – 应计算 softmax 的轴或轴集。跨这些维度求和的 softmax 输出应加起来为 \(1\)

返回值::

一个与 x 形状相同的数组。

返回类型::

Array

注意

如果任何输入值是 +inf,结果将全部为 NaN:这反映了 inf / inf 在浮点数学环境中没有定义的事实。

另请参阅

log_softmax()