jax.numpy.power

内容

jax.numpy.power#

jax.numpy.power(x1, x2, /)[source]#

计算逐元素基数 x1 的指数 x2

JAX 实现 numpy.power.

参数:
  • x1 (ArrayLike) – 标量或数组。指定基数。

  • x2 (ArrayLike) – 标量或数组。指定指数。 x1x2 应具有相同的形状或广播兼容。

返回:

一个数组,包含与输入具有相同数据类型的 x2 的基数 x1 指数。

返回类型:

数组

注意

  • x2 是一个具体的整数标量时,jnp.power 会降级为 jax.lax.integer_pow()

  • x2 是一个被追踪的标量或数组时,jnp.power 会降级为 jax.lax.pow()

  • jnp.power 会对整数类型取负整数次方时抛出 TypeError

  • jnp.power 会对负值取非整数次方时返回 nan

另请参阅

示例

具有标量整数的输入

>>> jnp.power(4, 3)
Array(64, dtype=int32, weak_type=True)

具有相同形状的输入

>>> x1 = jnp.array([2, 4, 5])
>>> x2 = jnp.array([3, 0.5, 2])
>>> jnp.power(x1, x2)
Array([ 8.,  2., 25.], dtype=float32)

具有广播兼容性的输入

>>> x3 = jnp.array([-2, 3, 1])
>>> x4 = jnp.array([[4, 1, 6],
...                 [1.3, 3, 5]])
>>> jnp.power(x3, x4)
Array([[16.,  3.,  1.],
       [nan, 27.,  1.]], dtype=float32)