jax.numpy.sqrt#
- jax.numpy.sqrt(x, /)[源代码]#
计算输入数组的逐元素非负平方根。
JAX 对
numpy.sqrt
的实现。- 参数:
x (ArrayLike) – 输入数组或标量。
- 返回:
一个包含
x
元素非负平方根的数组。- 返回类型:
注意
对于实值负输入,
jnp.sqrt
会产生nan
输出。对于复值负输入,
jnp.sqrt
会产生complex
输出。
另请参阅
jax.numpy.square()
: 计算输入的逐元素平方。jax.numpy.power()
: 计算x2
的逐元素基数x1
指数。
示例
>>> x = jnp.array([-8-6j, 1j, 4]) >>> with jnp.printoptions(precision=3, suppress=True): ... jnp.sqrt(x) Array([1. -3.j , 0.707+0.707j, 2. +0.j ], dtype=complex64) >>> jnp.sqrt(-1) Array(nan, dtype=float32, weak_type=True)