jax.numpy.add#

jax.numpy.add = <jnp.ufunc 'add'>#

对两个数组进行元素级相加。

numpy.add 的 JAX 实现。这是一个通用函数,并支持 jax.numpy.ufunc 中描述的额外 API。此函数提供 JAX 数组的 + 运算符的实现。

参数:
  • x – 要相加的数组。必须可广播为通用形状。

  • y – 要相加的数组。必须可广播为通用形状。

  • args (ArrayLike)

  • out (None)

  • where (None)

返回值:

包含元素级相加结果的数组。

返回类型:

Any

示例

显式调用 add

>>> x = jnp.arange(4)
>>> jnp.add(x, 10)
Array([10, 11, 12, 13], dtype=int32)

通过 + 运算符调用 add

>>> x + 10
Array([10, 11, 12, 13], dtype=int32)