jax.scipy.linalg.inv#
- jax.scipy.linalg.inv(a, overwrite_a=False, check_finite=True)[源代码]#
返回一个方阵的逆矩阵
JAX实现的
scipy.linalg.inv()
。- 参数:
- 返回:
形状为
(..., N, N)
的数组,包含输入的逆矩阵。- 返回类型:
注意事项
在大多数情况下,显式计算矩阵的逆是不明智的。 例如,要计算
x = inv(A) @ b
,使用直接求解(例如jax.scipy.linalg.solve()
)会更高效且数值精度更高。另请参阅
jax.numpy.linalg.inv()
:用于矩阵求逆的 NumPy 风格 APIjax.scipy.linalg.solve()
:直接线性求解器
示例
计算 3x3 矩阵的逆矩阵
>>> a = jnp.array([[1., 2., 3.], ... [2., 4., 2.], ... [3., 2., 1.]]) >>> a_inv = jax.scipy.linalg.inv(a) >>> a_inv Array([[ 0. , -0.25 , 0.5 ], [-0.25 , 0.5 , -0.25000003], [ 0.5 , -0.25 , 0. ]], dtype=float32)
检查与逆矩阵相乘是否得到单位矩阵
>>> jnp.allclose(a @ a_inv, jnp.eye(3), atol=1E-5) Array(True, dtype=bool)
将逆矩阵乘以向量
b
,以找到a @ x = b
的解>>> b = jnp.array([1., 4., 2.]) >>> a_inv @ b Array([ 0. , 1.25, -0.5 ], dtype=float32)
但是请注意,在这种情况下显式计算逆矩阵可能会导致性能不佳和精度损失,尤其是在问题规模增大时。 相反,您应该使用像
jax.scipy.linalg.solve()
这样的直接求解器>>> jax.scipy.linalg.solve(a, b) Array([ 0. , 1.25, -0.5 ], dtype=float32)