jax.numpy.logaddexp

内容

jax.numpy.logaddexp#

jax.numpy.logaddexp(x1, x2, /)[source]#

计算 log(exp(x1) + exp(x2)) 避免溢出。

JAX 实现 numpy.logaddexp

参数:
  • x1 (ArrayLike) – 输入数组

  • x2 (ArrayLike) – 输入数组

返回值:

包含结果的数组。

返回类型:

Array

示例

>>> x1 = jnp.array([1, 2, 3])
>>> x2 = jnp.array([4, 5, 6])
>>> result1 = jnp.logaddexp(x1, x2)
>>> result2 = jnp.log(jnp.exp(x1) + jnp.exp(x2))
>>> print(jnp.allclose(result1, result2))
True