jax.numpy.copysign

内容

jax.numpy.copysign#

jax.numpy.copysign(x1, x2, /)[source]#

x2 中每个元素的符号复制到 x1 中的对应元素。

JAX 实现 numpy.copysign.

参数:
  • x1 (ArrayLike) – 输入数组

  • x2 (ArrayLike) – 用于确定符号的数组,必须与 x1 广播兼容

返回值:

包含 x1 的可能已更改元素的数组对象,始终提升为不精确数据类型,并且形状为 jnp.broadcast_shapes(x1.shape, x2.shape)

返回类型:

数组

示例

>>> x1 = jnp.array([5, 2, 0])
>>> x2 = -1
>>> jnp.copysign(x1, x2)
Array([-5., -2., -0.], dtype=float32)
>>> x1 = jnp.array([6, 8, 0])
>>> x2 = 2
>>> jnp.copysign(x1, x2)
Array([6., 8., 0.], dtype=float32)
>>> x1 = jnp.array([2, -3])
>>> x2 = jnp.array([[1],[-4], [5]])
>>> jnp.copysign(x1, x2)
Array([[ 2.,  3.],
       [-2., -3.],
       [ 2.,  3.]], dtype=float32)