jax.scipy.special.logsumexp#
- jax.scipy.special.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: Literal[False] = False, where: ArrayLike | None = None) Array [source]#
- jax.scipy.special.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, *, return_sign: Literal[True], where: ArrayLike | None = None) tuple[Array, Array]
- jax.scipy.special.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) Array | tuple[Array, Array]
对数-求和-指数约简。
JAX 实现
scipy.special.logsumexp()
.\[\mathrm{logsumexp}(a) = \mathrm{log} \sum_j b \cdot \mathrm{exp}(a_{ij})\]其中 \(j\) 索引跨越要约简的一个或多个维度。
- 参数:
a – 输入数组
axis – 要约简的轴或轴集。可以是
None
、整数或整数元组。b – \(\mathrm{exp}(a)\) 的缩放因子。必须可广播到 a 的形状。
keepdims – 如果为
True
,则约简的轴将保留在输出中,作为大小为 1 的维度。return_sign – 如果为
True
,则输出将是(result, sign)
对,其中sign
是求和的符号,result
包含其绝对值的自然对数。如果为False
,则仅返回result
,如果求和为负,则它将包含 NaN 值。where – 要包含在约简中的元素。
- 返回值:
一个数组
result
或一对数组(result, sign)
,具体取决于return_sign
参数的值。