jax.scipy.linalg.sqrtm

内容

jax.scipy.linalg.sqrtm#

jax.scipy.linalg.sqrtm(A, blocksize=1)[source]#

计算矩阵平方根

JAX 实现 scipy.linalg.sqrtm().

参数:
  • A (ArrayLike) – 形状为 (N, N) 的数组

  • blocksize (int) – 在 JAX 中不支持;JAX 始终使用 blocksize=1

返回值:

形状为 (N, N) 的数组,包含 A 的矩阵平方根

返回值类型:

数组

示例

>>> a = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> sqrt_a = jax.scipy.linalg.sqrtm(a)
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(sqrt_a)
[[0.92+0.71j 0.54+0.j   0.92-0.71j]
 [0.54+0.j   1.85+0.j   0.54-0.j  ]
 [0.92-0.71j 0.54-0.j   0.92+0.71j]]

根据定义,矩阵平方根与自身的矩阵乘法应等于输入

>>> jnp.allclose(a, sqrt_a @ sqrt_a)
Array(True, dtype=bool)

注意

此函数实现了 [1] 中描述的复舒尔方法。它不使用递归分块来加速计算,因为 JAX 中还没有可用的西尔维斯特方程求解器。

参考文献