jax.numpy.sign

内容

jax.numpy.sign#

jax.numpy.sign(x, /)[source]#

返回输入的符号的逐元素指示。

numpy.sign 的 JAX 实现。

对于实值输入,x 的符号为

\[\begin{split}\mathrm{sign}(x) = \begin{cases} 1, & x > 0\\ 0, & x = 0\\ -1, & x < 0 \end{cases}\end{split}\]

对于复数值输入,jnp.sign 返回一个表示相位的单位向量。对于广义情况,x 的符号由下式给出

\[\begin{split}\mathrm{sign}(x) = \begin{cases} \frac{x}{abs(x)}, & x \ne 0\\ 0, & x = 0 \end{cases}\end{split}\]
参数:

x (ArrayLike) – 输入数组或标量。

返回值:

一个与 x 形状和数据类型相同的数组,包含符号指示。

返回类型:

Array

参见

示例

对于实值输入

>>> x = jnp.array([0., -3., 7.])
>>> jnp.sign(x)
Array([ 0., -1.,  1.], dtype=float32)

对于复数输入

>>> x1 = jnp.array([1, 3+4j, 5j])
>>> jnp.sign(x1)
Array([1. +0.j , 0.6+0.8j, 0. +1.j ], dtype=complex64)