jax.numpy.subtract#
- jax.numpy.subtract = <jnp.ufunc 'subtract'>#
对两个数组进行逐元素相减。
JAX 对
numpy.subtract
的实现。这是一个通用函数,支持在jax.numpy.ufunc
中描述的额外 API。此函数提供了 JAX 数组的-
运算符的实现。- 参数:
x – 要相减的数组。必须可广播为公共形状。
y – 要相减的数组。必须可广播为公共形状。
args (类似数组)
out (None)
where (None)
- 返回值:
包含逐元素相减结果的数组。
- 返回类型:
任意类型
示例
显式调用
subtract
>>> x = jnp.arange(4) >>> jnp.subtract(x, 10) Array([-10, -9, -8, -7], dtype=int32)
通过
-
运算符调用subtract
>>> x - 10 Array([-10, -9, -8, -7], dtype=int32)