jax.numpy.logaddexp#

jax.numpy.logaddexp = <jnp.ufunc 'logaddexp'>#

计算 log(exp(x1) + exp(x2)),避免溢出。

numpy.logaddexp 的 JAX 实现

参数:
  • x1 – 输入数组

  • x2 – 输入数组

  • args (类数组)

  • out (None)

  • where (None)

返回值:

包含结果的数组。

返回类型:

Any

示例

>>> x1 = jnp.array([1, 2, 3])
>>> x2 = jnp.array([4, 5, 6])
>>> result1 = jnp.logaddexp(x1, x2)
>>> result2 = jnp.log(jnp.exp(x1) + jnp.exp(x2))
>>> print(jnp.allclose(result1, result2))
True