jax.numpy.vecmat#
- jax.numpy.vecmat(x1, x2, /)[源代码]#
批量共轭向量-矩阵乘积。
JAX 中
numpy.vecmat()
的实现。- 参数:
x1 (ArrayLike) – 形状为
(..., M)
的数组。x2 (ArrayLike) – 形状为
(..., M, N)
的数组。前导维度必须与x1
的前导维度广播兼容。
- 返回:
形状为
(..., N)
的数组,包含批量共轭向量-矩阵乘积。- 返回类型:
另请参阅
jax.numpy.linalg.vecdot()
:批量向量积。jax.numpy.matvec()
:矩阵-向量积。jax.numpy.matmul()
:通用矩阵乘法。
示例
简单的向量-矩阵乘积
>>> x1 = jnp.array([[1, 2, 3]]) >>> x2 = jnp.array([[4, 5], ... [6, 7], ... [8, 9]]) >>> jnp.vecmat(x1, x2) Array([[40, 46]], dtype=int32)
批量向量-矩阵乘积
>>> x1 = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.vecmat(x1, x2) Array([[ 40, 46], [ 94, 109]], dtype=int32)