jax.numpy.logaddexp2#
- jax.numpy.logaddexp2(x1, x2, /)[source]#
以 2 为底的指数输入之和的对数,避免溢出。
JAX 实现
numpy.logaddexp2
。- 参数:
x1 (ArrayLike) – 输入数组或标量。
x2 (ArrayLike) – 输入数组或标量。
x1
和x2
应该具有相同的形状或广播兼容。
- 返回值:
包含结果的数组,\(log_2(2^{x1}+2^{x2})\),逐元素计算。
- 返回类型:
参见
jax.numpy.logaddexp()
:逐元素计算log(exp(x1) + exp(x2))
。jax.numpy.log2()
:逐元素计算x
的以 2 为底的对数。
示例
>>> x1 = jnp.array([[3, -1, 4], ... [8, 5, -2]]) >>> x2 = jnp.array([2, 3, -5]) >>> result1 = jnp.logaddexp2(x1, x2) >>> result2 = jnp.log2(jnp.exp2(x1) + jnp.exp2(x2)) >>> jnp.allclose(result1, result2) Array(True, dtype=bool)