jax.scipy.linalg.expm#
- jax.scipy.linalg.expm(A, *, upper_triangular=False, max_squarings=16)[源代码]#
计算矩阵指数
scipy.linalg.expm()
的 JAX 实现。- 参数:
A (ArrayLike) – 形状为
(..., N, N)
的数组upper_triangular (bool) – 如果为 True,则假定
A
是上三角矩阵。默认值=False。max_squarings (int) – 缩放和平方近似方法中的平方次数(默认值:16)。
- 返回:
形状为
(..., N, N)
的数组,其中包含A
的矩阵指数。- 返回类型:
说明
这使用缩放和平方近似方法,其计算复杂度由可选的
max_squarings
参数控制。理论上,所需的平方次数为max(0, ceil(log2(norm(A))) - c)
,其中norm(A)
是 L1 范数,对于 float64/complex128,c=2.42
,对于 float32/complex64,c=1.97
。示例
expm
是矩阵指数,其性质与更熟悉的标量指数相似。对于标量a
和b
, \(e^{a + b} = e^a e^b\)。但是,对于矩阵,此属性仅在A
和B
可交换时(AB = BA
)成立。在这种情况下,expm(A+B) = expm(A) @ expm(B)
>>> A = jnp.array([[2, 0], ... [0, 1]]) >>> B = jnp.array([[3, 0], ... [0, 4]]) >>> jnp.allclose(jax.scipy.linalg.expm(A+B), ... jax.scipy.linalg.expm(A) @ jax.scipy.linalg.expm(B), ... rtol=0.0001) Array(True, dtype=bool)
如果矩阵
X
是可逆的,则expm(X @ A @ inv(X)) = X @ expm(A) @ inv(X)
>>> X = jnp.array([[3, 1], ... [2, 5]]) >>> X_inv = jax.scipy.linalg.inv(X) >>> jnp.allclose(jax.scipy.linalg.expm(X @ A @ X_inv), ... X @ jax.scipy.linalg.expm(A) @ X_inv) Array(True, dtype=bool)