jax.numpy.matvec#

jax.numpy.matvec(x1, x2, /)[源代码]#

批量矩阵向量积。

numpy.matvec() 的 JAX 实现。

参数:
  • x1 (ArrayLike) – 形状为 (..., M, N) 的数组

  • x2 (类数组) – 形状为 (..., N) 的数组。前导维度必须与 x1 的前导维度进行广播兼容。

返回:

一个形状为 (..., M) 的数组,包含批量的矩阵-向量积。

返回类型:

数组

另请参阅

示例

简单的矩阵-向量积

>>> x1 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> x2 = jnp.array([7, 8, 9])
>>> jnp.matvec(x1, x2)
Array([ 50, 122], dtype=int32)

批量的矩阵-向量积

>>> x2 = jnp.array([[7, 8, 9],
...                 [5, 6, 7]])
>>> jnp.matvec(x1, x2)
Array([[ 50, 122],
       [ 38,  92]], dtype=int32)