jax.numpy.logaddexp#
- jax.numpy.logaddexp = <jnp.ufunc 'logaddexp'>#
计算
log(exp(x1) + exp(x2))
,避免溢出。numpy.logaddexp
的 JAX 实现- 参数:
x1 – 输入数组
x2 – 输入数组
args (类数组)
out (None)
where (None)
- 返回值:
包含结果的数组。
- 返回类型:
Any
示例
>>> x1 = jnp.array([1, 2, 3]) >>> x2 = jnp.array([4, 5, 6]) >>> result1 = jnp.logaddexp(x1, x2) >>> result2 = jnp.log(jnp.exp(x1) + jnp.exp(x2)) >>> print(jnp.allclose(result1, result2)) True