jax.numpy.logaddexp#
- jax.numpy.logaddexp(x1, x2, /)[source]#
计算
log(exp(x1) + exp(x2))
避免溢出。JAX 实现
numpy.logaddexp
- 参数:
x1 (ArrayLike) – 输入数组
x2 (ArrayLike) – 输入数组
- 返回值:
包含结果的数组。
- 返回类型:
示例
>>> x1 = jnp.array([1, 2, 3]) >>> x2 = jnp.array([4, 5, 6]) >>> result1 = jnp.logaddexp(x1, x2) >>> result2 = jnp.log(jnp.exp(x1) + jnp.exp(x2)) >>> print(jnp.allclose(result1, result2)) True