jax.numpy.linalg.lstsq#

jax.numpy.linalg.lstsq(a, b, rcond=None, *, numpy_resid=False)[源代码]#

返回线性方程的最小二乘解。

numpy.linalg.lstsq() 的 JAX 实现。

参数:
  • a (ArrayLike) – 形状为 (M, N) 的数组,表示系数矩阵。

  • b (类数组) – 形状为 (M,)(M, K) 的数组,表示右侧。

  • rcond (float | None | None) – 用于小奇异值的截断比率。小于 rcond * largest_singular_value 的奇异值将被视为零。如果为 None (默认),则将使用最佳值来减少浮点误差。

  • numpy_resid (bool) – 如果为 True,则以与 NumPy 的 linalg.lstsq 相同的方式计算并返回残差。如果您想精确复制 NumPy 的行为,这是必要的。如果为 False (默认),则使用更高效的方法来计算残差。

返回:

数组的元组 (x, resid, rank, s),其中

  • x 是一个形状为 (N,)(N, K) 的数组,包含最小二乘解。

  • resid 是形状为 ()(K,) 的平方残差之和。

  • rank 是矩阵 a 的秩。

  • s 是矩阵 a 的奇异值。

返回类型:

tuple[Array, Array, Array, Array]

示例

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([5, 6])
>>> x, _, _, _ = jnp.linalg.lstsq(a, b)
>>> with jnp.printoptions(precision=3):
...   print(x)
[-4.   4.5]