jax.numpy.subtract#

jax.numpy.subtract = <jnp.ufunc 'subtract'>#

逐元素减去两个数组。

numpy.subtract 的 JAX 实现。这是一个通用函数,并支持 jax.numpy.ufunc 中描述的附加 API。此函数为 JAX 数组提供 - 运算符的实现。

参数:
  • x – 要减去的数组。必须可广播到通用形状。

  • y – 要减去的数组。必须可广播到通用形状。

  • args (ArrayLike)

  • out (None)

  • where (None)

返回值:

包含逐元素减法结果的数组。

返回类型:

Any

示例

显式调用 subtract

>>> x = jnp.arange(4)
>>> jnp.subtract(x, 10)
Array([-10,  -9,  -8,  -7], dtype=int32)

通过 - 运算符调用 subtract

>>> x - 10
Array([-10,  -9,  -8,  -7], dtype=int32)