jax.numpy.square#

jax.numpy.square(x, /)[源代码]#

计算输入数组的逐元素平方。

numpy.square 的 JAX 实现。

参数:

x (ArrayLike) – 输入数组或标量。

返回:

一个包含 x 元素平方的数组。

返回类型:

数组

注意

jnp.square 等价于计算 jnp.power(x, 2)

另请参阅

示例

>>> x = jnp.array([3, -2, 5.3, 1])
>>> jnp.square(x)
Array([ 9.      ,  4.      , 28.090002,  1.      ], dtype=float32)
>>> jnp.power(x, 2)
Array([ 9.      ,  4.      , 28.090002,  1.      ], dtype=float32)

对于整数输入

>>> x1 = jnp.array([2, 4, 5, 6])
>>> jnp.square(x1)
Array([ 4, 16, 25, 36], dtype=int32)

对于复数值输入

>>> x2 = jnp.array([1-3j, -1j, 2])
>>> jnp.square(x2)
Array([-8.-6.j, -1.+0.j,  4.+0.j], dtype=complex64)