jax.numpy.power#

jax.numpy.power(x1, x2, /)[源代码]#

计算元素级的 x1x2 次方。

JAX 对 numpy.power 的实现。

参数:
  • x1 (ArrayLike) – 标量或数组。指定底数。

  • x2 (ArrayLike) – 标量或数组。指定指数。x1x2 应该具有相同的形状或可以进行广播兼容。

返回:

一个数组,包含 x1x2 次方,其数据类型与输入相同。

返回类型:

数组

注意

  • x2 是一个具体的整数标量时,jnp.power 会降级为 jax.lax.integer_pow()

  • x2 是一个跟踪标量或一个数组时,jnp.power 会降级为 jax.lax.pow()

  • jnp.power 对于整数类型取负整数次方会引发 TypeError

  • jnp.power 对于负值取非整数次方会返回 nan

另请参阅

  • jax.lax.pow():计算元素级的幂运算,\(x^y\)

  • jax.lax.integer_pow():计算元素级的幂运算 \(x^y\),其中 \(y\) 是一个固定的整数。

  • jax.numpy.float_power():通过提升到非精确数据类型,计算第一个数组的第二个数组次方的元素级运算。

  • jax.numpy.pow():计算第一个数组的第二个数组次方的元素级运算。

示例

具有标量整数的输入

>>> jnp.power(4, 3)
Array(64, dtype=int32, weak_type=True)

具有相同形状的输入

>>> x1 = jnp.array([2, 4, 5])
>>> x2 = jnp.array([3, 0.5, 2])
>>> jnp.power(x1, x2)
Array([ 8.,  2., 25.], dtype=float32)

具有广播兼容性的输入

>>> x3 = jnp.array([-2, 3, 1])
>>> x4 = jnp.array([[4, 1, 6],
...                 [1.3, 3, 5]])
>>> jnp.power(x3, x4)
Array([[16.,  3.,  1.],
       [nan, 27.,  1.]], dtype=float32)