jax.numpy.linalg.lstsq#
- jax.numpy.linalg.lstsq(a, b, rcond=None, *, numpy_resid=False)[源代码]#
返回线性方程的最小二乘解。
numpy.linalg.lstsq()
的 JAX 实现。- 参数:
a (ArrayLike) – 形状为
(M, N)
的数组,表示系数矩阵。b (类数组) – 形状为
(M,)
或(M, K)
的数组,表示右侧。rcond (float | None | None) – 用于小奇异值的截断比率。小于
rcond * largest_singular_value
的奇异值将被视为零。如果为 None (默认),则将使用最佳值来减少浮点误差。numpy_resid (bool) – 如果为 True,则以与 NumPy 的 linalg.lstsq 相同的方式计算并返回残差。如果您想精确复制 NumPy 的行为,这是必要的。如果为 False (默认),则使用更高效的方法来计算残差。
- 返回:
数组的元组
(x, resid, rank, s)
,其中x
是一个形状为(N,)
或(N, K)
的数组,包含最小二乘解。resid
是形状为()
或(K,)
的平方残差之和。rank
是矩阵a
的秩。s
是矩阵a
的奇异值。
- 返回类型:
示例
>>> a = jnp.array([[1, 2], ... [3, 4]]) >>> b = jnp.array([5, 6]) >>> x, _, _, _ = jnp.linalg.lstsq(a, b) >>> with jnp.printoptions(precision=3): ... print(x) [-4. 4.5]