jax.numpy.vecmat#

jax.numpy.vecmat(x1, x2, /)[源代码]#

批量共轭向量-矩阵乘积。

JAX 中 numpy.vecmat() 的实现。

参数:
  • x1 (ArrayLike) – 形状为 (..., M) 的数组。

  • x2 (ArrayLike) – 形状为 (..., M, N) 的数组。前导维度必须与 x1 的前导维度广播兼容。

返回:

形状为 (..., N) 的数组,包含批量共轭向量-矩阵乘积。

返回类型:

数组

另请参阅

示例

简单的向量-矩阵乘积

>>> x1 = jnp.array([[1, 2, 3]])
>>> x2 = jnp.array([[4, 5],
...                 [6, 7],
...                 [8, 9]])
>>> jnp.vecmat(x1, x2)
Array([[40, 46]], dtype=int32)

批量向量-矩阵乘积

>>> x1 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> jnp.vecmat(x1, x2)
Array([[ 40,  46],
       [ 94, 109]], dtype=int32)