jax.numpy.expm1#
- jax.numpy.expm1(x, /)[源代码]#
计算输入的每个元素的
exp(x)-1
。JAX 实现了
numpy.expm1
。- 参数:
x (ArrayLike) – 输入数组或标量。
- 返回:
一个数组,包含
x
中每个元素的exp(x)-1
,并提升到非精确数据类型。- 返回类型:
注意
对于小的
x
值,jnp.expm1
比直接计算exp(x)-1
的精度高得多。另请参阅
jax.numpy.log1p()
: 计算输入加一后的逐元素对数。jax.numpy.exp()
: 计算输入的逐元素指数。jax.numpy.exp2()
: 计算输入的每个元素的 2 的指数。
示例
>>> x = jnp.array([2, -4, 3, -1]) >>> with jnp.printoptions(precision=2, suppress=True): ... print(jnp.expm1(x)) [ 6.39 -0.98 19.09 -0.63] >>> with jnp.printoptions(precision=2, suppress=True): ... print(jnp.exp(x)-1) [ 6.39 -0.98 19.09 -0.63]
对于非常接近 0 的值,
jnp.expm1(x)
比jnp.exp(x)-1
准确得多>>> x1 = jnp.array([1e-4, 1e-6, 2e-10]) >>> jnp.expm1(x1) Array([1.0000500e-04, 1.0000005e-06, 2.0000000e-10], dtype=float32) >>> jnp.exp(x1)-1 Array([1.00016594e-04, 9.53674316e-07, 0.00000000e+00], dtype=float32)