jax.numpy.linalg.inv#
- jax.numpy.linalg.inv(a)[source]#
返回方阵的逆矩阵
JAX 实现
numpy.linalg.inv()
.- 参数:
a (ArrayLike) – 形状为
(..., N, N)
的数组,指定要求逆的方阵。- 返回:
形状为
(..., N, N)
的数组,包含输入的逆矩阵。- 返回类型:
备注
在大多数情况下,显式计算矩阵的逆矩阵是不明智的。例如,要计算
x = inv(A) @ b
,使用直接求解器(例如jax.scipy.linalg.solve()
)的性能更高且数值精度更高。参见
jax.scipy.linalg.inv()
:SciPy 风格的矩阵求逆 APIjax.numpy.linalg.solve()
:直接线性求解器
示例
计算 3x3 矩阵的逆矩阵
>>> a = jnp.array([[1., 2., 3.], ... [2., 4., 2.], ... [3., 2., 1.]]) >>> a_inv = jnp.linalg.inv(a) >>> a_inv Array([[ 0. , -0.25 , 0.5 ], [-0.25 , 0.5 , -0.25000003], [ 0.5 , -0.25 , 0. ]], dtype=float32)
检查乘以逆矩阵是否得到单位矩阵
>>> jnp.allclose(a @ a_inv, jnp.eye(3), atol=1E-5) Array(True, dtype=bool)
将逆矩阵乘以向量
b
,以找到a @ x = b
的解>>> b = jnp.array([1., 4., 2.]) >>> a_inv @ b Array([ 0. , 1.25, -0.5 ], dtype=float32)
但是,请注意,在这种情况下显式计算逆矩阵会导致性能下降和精度损失,因为问题的规模会不断增大。 相反,您应该使用直接求解器,例如
jax.numpy.linalg.solve()
>>> jnp.linalg.solve(a, b) Array([ 0. , 1.25, -0.5 ], dtype=float32)