jax.numpy.linalg.pinv#
- jax.numpy.linalg.pinv(a, rtol=None, hermitian=False, *, rcond=Deprecated)[源代码]#
计算矩阵的(Moore-Penrose)伪逆。
JAX 实现的
numpy.linalg.pinv()
。- 参数:
a (ArrayLike) – 形状为
(..., M, N)
的数组,包含要伪逆的矩阵。rtol (ArrayLike | None | None) – float 或形状为
a.shape[:-2]
的 array_like。指定小奇异值的截止值。形状为(...,)
。小奇异值的截止值;小于rtol * largest_singular_value
的奇异值将被视为零。默认值根据 dtype 的浮点精度确定。hermitian (bool) – 如果为 True,则假设输入为 Hermitian 矩阵,并使用更高效的算法(默认值:False)
rcond (ArrayLike | DeprecatedArg | None) – 已弃用的
rtol
参数别名。如果使用,将导致DeprecationWarning
。
- 返回值:
一个形状为
(..., N, M)
的数组,包含a
的伪逆。- 返回类型:
另请参阅
jax.numpy.linalg.inv()
:方阵的乘法逆矩阵。
注释
jax.numpy.linalg.pinv()
与numpy.linalg.pinv()
在 rcond 的默认值上有所不同:在 NumPy 中,默认值为 1e-15。在 JAX 中,默认值为10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps
。示例
>>> a = jnp.array([[1, 2], ... [3, 4], ... [5, 6]]) >>> a_pinv = jnp.linalg.pinv(a) >>> a_pinv Array([[-1.333332 , -0.33333257, 0.6666657 ], [ 1.0833322 , 0.33333272, -0.41666582]], dtype=float32)
只要输出不是秩亏的,伪逆就充当乘法逆
>>> jnp.allclose(a_pinv @ a, jnp.eye(2), atol=1E-4) Array(True, dtype=bool)