jax.scipy.special.logsumexp#

jax.scipy.special.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: Literal[False] = False, where: ArrayLike | None = None) Array[源代码]#
jax.scipy.special.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, *, return_sign: Literal[True], where: ArrayLike | None = None) tuple[Array, Array]
jax.scipy.special.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) Array | tuple[Array, Array]

对数求和指数规约。

scipy.special.logsumexp()的 JAX 实现。

\[\mathrm{logsumexp}(a) = \mathrm{log} \sum_j b \cdot \mathrm{exp}(a_{ij})\]

其中 \(j\) 索引范围涵盖一个或多个要规约的维度。

参数:
  • a – 输入数组

  • axis – 要规约的轴或多个轴。可以是 None、int 或 int 元组。

  • b\(\mathrm{exp}(a)\) 的缩放因子。必须可广播为 a 的形状。

  • keepdims – 如果为 True,则会将规约的轴保留在输出中,作为大小为 1 的维度。

  • return_sign – 如果为 True,则输出将为 (result, sign) 对,其中 sign 是总和的符号,result 包含其绝对值的对数。如果为 False,则仅返回 result,并且如果总和为负数,它将包含 NaN 值。

  • where – 要包含在规约中的元素。

返回:

根据 return_sign 参数的值,返回数组 result 或数组对 (result, sign)