jax.numpy.power#
- jax.numpy.power(x1, x2, /)[source]#
计算逐元素基数
x1
的指数x2
。JAX 实现
numpy.power
.- 参数:
x1 (ArrayLike) – 标量或数组。指定基数。
x2 (ArrayLike) – 标量或数组。指定指数。
x1
和x2
应具有相同的形状或广播兼容。
- 返回:
一个数组,包含与输入具有相同数据类型的
x2
的基数x1
指数。- 返回类型:
注意
当
x2
是一个具体的整数标量时,jnp.power
会降级为jax.lax.integer_pow()
。当
x2
是一个被追踪的标量或数组时,jnp.power
会降级为jax.lax.pow()
。jnp.power
会对整数类型取负整数次方时抛出TypeError
。jnp.power
会对负值取非整数次方时返回nan
。
另请参阅
jax.lax.pow()
: 计算逐元素幂运算,\(x^y\)。jax.lax.integer_pow()
: 计算逐元素幂运算 \(x^y\),其中 \(y\) 是一个固定整数。jax.numpy.float_power()
: 通过提升到非精确数据类型,计算第一个数组对第二个数组的逐元素幂运算。jax.numpy.pow()
: 计算第一个数组对第二个数组的逐元素幂运算。
示例
具有标量整数的输入
>>> jnp.power(4, 3) Array(64, dtype=int32, weak_type=True)
具有相同形状的输入
>>> x1 = jnp.array([2, 4, 5]) >>> x2 = jnp.array([3, 0.5, 2]) >>> jnp.power(x1, x2) Array([ 8., 2., 25.], dtype=float32)
具有广播兼容性的输入
>>> x3 = jnp.array([-2, 3, 1]) >>> x4 = jnp.array([[4, 1, 6], ... [1.3, 3, 5]]) >>> jnp.power(x3, x4) Array([[16., 3., 1.], [nan, 27., 1.]], dtype=float32)