jax.numpy.sqrt#

jax.numpy.sqrt(x, /)[源代码]#

计算输入数组的逐元素非负平方根。

JAX 对 numpy.sqrt 的实现。

参数:

x (ArrayLike) – 输入数组或标量。

返回:

一个包含 x 元素非负平方根的数组。

返回类型:

数组

注意

  • 对于实值负输入,jnp.sqrt 会产生 nan 输出。

  • 对于复值负输入,jnp.sqrt 会产生 complex 输出。

另请参阅

示例

>>> x = jnp.array([-8-6j, 1j, 4])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.sqrt(x)
Array([1.   -3.j   , 0.707+0.707j, 2.   +0.j   ], dtype=complex64)
>>> jnp.sqrt(-1)
Array(nan, dtype=float32, weak_type=True)