jax.numpy.sign#
- jax.numpy.sign(x, /)[源代码]#
返回输入元素的符号指示。
JAX 实现的
numpy.sign
。对于实值输入,
x
的符号为\[\begin{split}\mathrm{sign}(x) = \begin{cases} 1, & x > 0\\ 0, & x = 0\\ -1, & x < 0 \end{cases}\end{split}\]对于复值输入,
jnp.sign
返回一个表示相位的单位向量。对于一般情况,x
的符号由下式给出\[\begin{split}\mathrm{sign}(x) = \begin{cases} \frac{x}{abs(x)}, & x \ne 0\\ 0, & x = 0 \end{cases}\end{split}\]- 参数:
x (ArrayLike) – 输入数组或标量。
- 返回:
一个与
x
具有相同形状和 dtype 的数组,其中包含符号指示。- 返回类型:
另请参阅
jax.numpy.positive()
: 返回输入的元素级正值。jax.numpy.negative()
: 返回输入的元素级负值。
示例
对于实值输入
>>> x = jnp.array([0., -3., 7.]) >>> jnp.sign(x) Array([ 0., -1., 1.], dtype=float32)
对于复值输入
>>> x1 = jnp.array([1, 3+4j, 5j]) >>> jnp.sign(x1) Array([1. +0.j , 0.6+0.8j, 0. +1.j ], dtype=complex64)