jax.numpy.ldexp#

jax.numpy.ldexp(x1, x2, /)[源代码]#

计算 x1 * 2 ** x2

numpy.ldexp() 的 JAX 实现。

请注意,XLA 不提供 ldexp 操作,因此在 JAX 中通过标准乘法和求幂来实现此操作。

参数:
  • x1 (ArrayLike) – 实值输入数组。

  • x2 (ArrayLike) – 整数输入数组。必须与 x1 广播兼容。

返回:

x1 * 2 ** x2 按元素计算。

返回类型:

数组

另请参阅

示例

>>> x1 = jnp.arange(5.0)
>>> x2 = 10
>>> jnp.ldexp(x1, x2)
Array([   0., 1024., 2048., 3072., 4096.], dtype=float32)

ldexp 可用于重建 frexp 的输入

>>> x = jnp.array([2., 3., 5., 11.])
>>> m, e = jnp.frexp(x)
>>> m
Array([0.5   , 0.75  , 0.625 , 0.6875], dtype=float32)
>>> e
Array([2, 2, 3, 4], dtype=int32)
>>> jnp.ldexp(m, e)
Array([ 2.,  3.,  5., 11.], dtype=float32)