jax.numpy.subtract#

jax.numpy.subtract = <jnp.ufunc 'subtract'>#

对两个数组进行逐元素相减。

JAX 对 numpy.subtract 的实现。这是一个通用函数,支持在 jax.numpy.ufunc 中描述的额外 API。此函数提供了 JAX 数组的 - 运算符的实现。

参数:
  • x – 要相减的数组。必须可广播为公共形状。

  • y – 要相减的数组。必须可广播为公共形状。

  • args (类似数组)

  • out (None)

  • where (None)

返回值:

包含逐元素相减结果的数组。

返回类型:

任意类型

示例

显式调用 subtract

>>> x = jnp.arange(4)
>>> jnp.subtract(x, 10)
Array([-10,  -9,  -8,  -7], dtype=int32)

通过 - 运算符调用 subtract

>>> x - 10
Array([-10,  -9,  -8,  -7], dtype=int32)