jax.numpy.linalg.pinv

内容

jax.numpy.linalg.pinv#

jax.numpy.linalg.pinv(a, rtol=None, hermitian=False, *, rcond=Deprecated)[source]#

计算矩阵的(Moore-Penrose)伪逆。

JAX 实现的 numpy.linalg.pinv()

参数:
  • a (ArrayLike) – 形状为 (..., M, N) 的数组,包含要进行伪逆的矩阵。

  • rtol (ArrayLike | None | None) – 浮点数或形状为 a.shape[:-2] 的类数组。指定小奇异值的截止值,形状为 (...,)。小奇异值的截止值;小于 rtol * largest_singular_value 的奇异值将被视为零。默认值根据 dtype 的浮点精度确定。

  • hermitian (bool) – 如果为 True,则假定输入为 Hermitian,并使用更有效的算法(默认值:False)

  • rcond (ArrayLike | DeprecatedArg | None) – rtol 参数的已弃用别名。如果使用,将导致 DeprecationWarning

返回:

形状为 (..., N, M) 的数组,包含 a 的伪逆。

返回类型:

数组

另请参阅

备注

jax.numpy.linalg.prng()numpy.linalg.prng()rcond` 的默认值方面有所不同:在 NumPy 中,默认值为 1e-15。在 JAX 中,默认值为 10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps

示例

>>> a = jnp.array([[1, 2],
...                [3, 4],
...                [5, 6]])
>>> a_pinv = jnp.linalg.pinv(a)
>>> a_pinv  
Array([[-1.333332  , -0.33333257,  0.6666657 ],
       [ 1.0833322 ,  0.33333272, -0.41666582]], dtype=float32)

只要输出不是秩亏损的,伪逆就会像乘法逆一样工作。

>>> jnp.allclose(a_pinv @ a, jnp.eye(2), atol=1E-4)
Array(True, dtype=bool)