jax.numpy.linalg.matrix_rank

内容

jax.numpy.linalg.matrix_rank#

jax.numpy.linalg.matrix_rank(M, rtol=None, *, tol=Deprecated)[source]#

计算矩阵的秩。

JAX 实现 numpy.linalg.matrix_rank().

秩通过奇异值分解 (SVD) 计算,并由大于指定容差的奇异值的数量决定。

参数:
  • M (ArrayLike) – 形状为 (..., N, K) 的数组,其秩将被计算。

  • rtol (ArrayLike | None) – 形状为 (...) 的可选数组,指定容差。小于 rtol * largest_singular_value 的奇异值被认为是零。如果 rtol 为 None(默认值),则根据输入的浮点精度选择一个合理的默认值。

  • tol (ArrayLike | DeprecatedArg | None) – rtol 参数的已弃用别名。如果使用,将导致 DeprecationWarning

返回:

形状为 a.shape[-2] 的数组,给出矩阵的秩。

返回类型:

数组

笔记

对于具有非常小的奇异值或数值上病态的矩阵,秩计算可能不准确。在这种情况下,请考虑调整 rtol 参数或使用更专业的秩计算方法。

示例

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.linalg.matrix_rank(a)
Array(2, dtype=int32)
>>> b = jnp.array([[1, 0],  # Rank-deficient matrix
...                [0, 0]])
>>> jnp.linalg.matrix_rank(b)
Array(1, dtype=int32)