jax.numpy.square#
- jax.numpy.square(x, /)[源代码]#
计算输入数组的逐元素平方。
numpy.square
的 JAX 实现。- 参数:
x (ArrayLike) – 输入数组或标量。
- 返回:
一个包含
x
元素平方的数组。- 返回类型:
注意
jnp.square
等价于计算jnp.power(x, 2)
。另请参阅
jax.numpy.sqrt()
:计算输入数组的逐元素非负平方根。jax.numpy.power()
:计算x2
的逐元素底数x1
指数。jax.lax.integer_pow()
: 计算逐元素幂 \(x^y\),其中 \(y\) 是一个固定的整数。jax.numpy.float_power()
: 计算第一个数组以第二个数组为指数的逐元素幂,通过提升到非精确数据类型来实现。
示例
>>> x = jnp.array([3, -2, 5.3, 1]) >>> jnp.square(x) Array([ 9. , 4. , 28.090002, 1. ], dtype=float32) >>> jnp.power(x, 2) Array([ 9. , 4. , 28.090002, 1. ], dtype=float32)
对于整数输入
>>> x1 = jnp.array([2, 4, 5, 6]) >>> jnp.square(x1) Array([ 4, 16, 25, 36], dtype=int32)
对于复数值输入
>>> x2 = jnp.array([1-3j, -1j, 2]) >>> jnp.square(x2) Array([-8.-6.j, -1.+0.j, 4.+0.j], dtype=complex64)