jax.scipy.linalg.inv#
- jax.scipy.linalg.inv(a, overwrite_a=False, check_finite=True)[source]#
返回方阵的逆
JAX 实现
scipy.linalg.inv()
.- 参数::
- 返回::
形状为
(..., N, N)
的数组,包含输入的逆。- 返回类型::
备注
在大多数情况下,显式计算矩阵的逆运算是不明智的。例如,要计算
x = inv(A) @ b
,使用直接求解方法,例如jax.scipy.linalg.solve()
,性能更高,数值精度也更高。另请参阅
jax.numpy.linalg.inv()
: NumPy 风格的矩阵逆运算 APIjax.scipy.linalg.solve()
: 直接线性求解器
示例
计算 3x3 矩阵的逆运算
>>> a = jnp.array([[1., 2., 3.], ... [2., 4., 2.], ... [3., 2., 1.]]) >>> a_inv = jax.scipy.linalg.inv(a) >>> a_inv Array([[ 0. , -0.25 , 0.5 ], [-0.25 , 0.5 , -0.25000003], [ 0.5 , -0.25 , 0. ]], dtype=float32)
检查与逆矩阵相乘是否得到单位矩阵
>>> jnp.allclose(a @ a_inv, jnp.eye(3), atol=1E-5) Array(True, dtype=bool)
将逆矩阵乘以向量
b
,以找到a @ x = b
的解>>> b = jnp.array([1., 4., 2.]) >>> a_inv @ b Array([ 0. , 1.25, -0.5 ], dtype=float32)
但是,请注意,在这种情况下,显式计算逆矩阵可能会导致随着问题规模的增长而出现性能下降和精度损失。相反,您应该使用直接求解器,例如
jax.scipy.linalg.solve()
>>> jax.scipy.linalg.solve(a, b) Array([ 0. , 1.25, -0.5 ], dtype=float32)