jax.numpy.matvec#
- jax.numpy.matvec(x1, x2, /)[源代码]#
批量矩阵向量积。
numpy.matvec()
的 JAX 实现。- 参数:
x1 (ArrayLike) – 形状为
(..., M, N)
的数组x2 (类数组) – 形状为
(..., N)
的数组。前导维度必须与x1
的前导维度进行广播兼容。
- 返回:
一个形状为
(..., M)
的数组,包含批量的矩阵-向量积。- 返回类型:
另请参阅
jax.numpy.linalg.vecdot()
: 批量向量积。jax.numpy.vecmat()
: 向量-矩阵积。jax.numpy.matmul()
: 通用矩阵乘法。
示例
简单的矩阵-向量积
>>> x1 = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> x2 = jnp.array([7, 8, 9]) >>> jnp.matvec(x1, x2) Array([ 50, 122], dtype=int32)
批量的矩阵-向量积
>>> x2 = jnp.array([[7, 8, 9], ... [5, 6, 7]]) >>> jnp.matvec(x1, x2) Array([[ 50, 122], [ 38, 92]], dtype=int32)