jax.numpy.multiply#
- jax.numpy.multiply = <jnp.ufunc 'multiply'>#
按元素相乘两个数组。
numpy.multiply
的 JAX 实现。这是一个通用函数,并支持在jax.numpy.ufunc
中描述的附加 API。此函数为 JAX 数组提供了*
运算符的实现。- 参数:
x – 要相乘的数组。必须可广播到公共形状。
y – 要相乘的数组。必须可广播到公共形状。
args (ArrayLike)
out (None)
where (None)
- 返回:
包含按元素相乘结果的数组。
- 返回类型:
Any
示例
显式调用
multiply
>>> x = jnp.arange(4) >>> jnp.multiply(x, 10) Array([ 0, 10, 20, 30], dtype=int32)
通过
*
运算符调用multiply
>>> x * 10 Array([ 0, 10, 20, 30], dtype=int32)