jax.numpy.logaddexp2

内容

jax.numpy.logaddexp2#

jax.numpy.logaddexp2(x1, x2, /)[source]#

以 2 为底的指数输入之和的对数,避免溢出。

JAX 实现 numpy.logaddexp2

参数:
  • x1 (ArrayLike) – 输入数组或标量。

  • x2 (ArrayLike) – 输入数组或标量。 x1x2 应该具有相同的形状或广播兼容。

返回值:

包含结果的数组,\(log_2(2^{x1}+2^{x2})\),逐元素计算。

返回类型:

数组

参见

示例

>>> x1 = jnp.array([[3, -1, 4],
...                 [8, 5, -2]])
>>> x2 = jnp.array([2, 3, -5])
>>> result1 = jnp.logaddexp2(x1, x2)
>>> result2 = jnp.log2(jnp.exp2(x1) + jnp.exp2(x2))
>>> jnp.allclose(result1, result2)
Array(True, dtype=bool)