jax.numpy.linalg.lstsq

内容

jax.numpy.linalg.lstsq#

jax.numpy.linalg.lstsq(a, b, rcond=None, *, numpy_resid=False)[source]#

返回线性方程的最小二乘解。

JAX 实现 numpy.linalg.lstsq().

参数:
  • a (ArrayLike) – 形状为 (M, N) 的数组,表示系数矩阵。

  • b (ArrayLike) – 形状为 (M,)(M, K) 的数组,表示右侧。

  • rcond (float | None | None) – 小奇异值的截止比率。小于 rcond * largest_singular_value 的奇异值将被视为零。如果为 None(默认值),则将使用最佳值来减少浮点错误。

  • numpy_resid (bool) – 如果为 True,则以与 NumPy 的 linalg.lstsq 相同的方式计算并返回残差。如果您想精确复制 NumPy 的行为,则需要这样做。如果为 False(默认值),则使用更有效的方法来计算残差。

返回值:

数组元组 (x, resid, rank, s),其中

  • x 是一个形状为 (N,)(N, K) 的数组,包含最小二乘解。

  • resid 是残差平方和,形状为 ()(K,)

  • rank 是矩阵 a 的秩。

  • s 是矩阵 a 的奇异值。

返回类型:

元组[Array, Array, Array, Array]

示例

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([5, 6])
>>> x, _, _, _ = jnp.linalg.lstsq(a, b)
>>> with jnp.printoptions(precision=3):
...   print(x)
[-4.   4.5]